Stability of sample weighting with sample density #541

Open
veni-vidi-vici-dormivi opened this issue Oct 7, 2024 · 6 comments

@veni-vidi-vici-dormivi
Collaborator

I took some time to look into @yquilcaille's distrib_cov._get_weights_nll function. Its purpose is to estimate the density of the predictor samples in "predictor space" and weight the samples accordingly: samples in well-sampled regions get less weight, while samples in sparsely sampled regions get more weight. I'll illustrate with an example:

If we have two predictors, we can plot them against each other and make a 2D histogram that counts the data points in each bin. This already gives a rough density estimate, since it shows how many samples fall into each bin, i.e. each region of the predictor space. To get a continuous density surface, we apply a RegularGridInterpolator to the bin counts (placed at the bin centers) and evaluate it at the actual sample points. This way, sample points that fall into the same bin do not necessarily get the same weight, but a weight that better reflects the local density. Below is a visualization of this example:
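The effect of the interpolation step can be seen in a minimal 1D sketch with made-up samples (not taken from the original function): two samples share the bin [0, 1), but the one near the edge next to an empty bin gets a much lower interpolated count, and hence a higher weight, than the one at the bin center.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

# Made-up 1D samples: two in the bin [0, 1), none in [1, 2), one in [2, 3]
samples = np.array([0.5, 0.95, 2.5])
edges = np.array([0.0, 1.0, 2.0, 3.0])

counts, _ = np.histogram(samples, bins=edges)  # counts per bin: [2, 0, 1]
centers = 0.5 * (edges[1:] + edges[:-1])       # bin centers: [0.5, 1.5, 2.5]

# Linear interpolation between bin centers, as in the snippets in this thread
interp = RegularGridInterpolator(
    points=(centers,),
    values=counts.astype(float),
    method="linear",
    bounds_error=False,
    fill_value=None,
)
density = interp(samples.reshape(-1, 1))
weights = 1 / density

for x, d, w in zip(samples, density, weights):
    print(f"sample {x:4.2f}: interpolated count {d:.2f}, weight {w:.2f}")
```

The sample at 0.95 gets an interpolated count of 1.1 (weight about 0.91), close to the isolated sample at 2.5 (count 1, weight 1), even though it shares its bin with another sample.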

Code
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import RegularGridInterpolator
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Simulate some 2D predictor data (two predictors)
np.random.seed(42)
n_samples = 200
predictor_1 = np.random.normal(0, 1, n_samples)
predictor_2 = np.random.normal(0, 1, n_samples)

# Combine into a 2D array
data = np.vstack([predictor_1, predictor_2]).T

# Define bins
n_bins_density = 40 # default of _get_weights_nll function (used as the number of bin edges, i.e. 39 bins)
mn, mx = np.nanmin(data, axis=0), np.nanmax(data, axis=0)
bins = np.linspace(
    (mn - 0.05 * (mx - mn)),
    (mx + 0.05 * (mx - mn)),
    n_bins_density,
)

# Create a multidimensional histogram
gmt_hist, edges = np.histogramdd(sample=data, bins=bins.T)

# Calculate the bin centers for interpolation
gmt_bins_center = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]
interp = RegularGridInterpolator(points=gmt_bins_center, values=gmt_hist, method = "linear", bounds_error=False, fill_value=None)

# Calculate weights as the inverse of the interpolated density
weights = 1 / interp(data)

# Create 3D figure with histogram bars and interpolated wireframe
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Create bars for the histogram
bin_centers_x, bin_centers_y = np.meshgrid(gmt_bins_center[0], gmt_bins_center[1], indexing="ij")  # "ij" so the bars line up with gmt_hist.ravel()
bar_width_x = np.diff(edges[0])[0] * 0.8  # Width of each bar in x
bar_width_y = np.diff(edges[1])[0] * 0.8  # Width of each bar in y

# Flatten the grid for bar plotting
x_flat = bin_centers_x.ravel()
y_flat = bin_centers_y.ravel()
z_flat = np.zeros_like(x_flat)  # Start the bars at 0 height
height_flat = gmt_hist.ravel()  # Histogram values as height
ax.bar3d(x_flat, y_flat, z_flat, bar_width_x, bar_width_y, height_flat, shade=True, color='pink', alpha=0.6)

# Evaluate the interpolator on a regular grid for the surface plot
grid_x, grid_y = np.meshgrid(np.linspace(mn[0], mx[0], n_bins_density), np.linspace(mn[1], mx[1], n_bins_density))
interp_points = np.vstack([grid_x.ravel(), grid_y.ravel()]).T
grid_z = interp(interp_points).reshape(grid_x.shape)

# Plot the interpolated surface, shifted up by 5 so it is visible above the bars
ax.plot_surface(grid_x, grid_y, grid_z + 5, color='cornflowerblue', alpha=0.7, linewidth=0.7)

# Set labels and title
ax.set_xlabel('Predictor 1')
ax.set_ylabel('Predictor 2')
ax.set_zlabel('Density (Histogram & Interpolation)')

# Change the view angle
ax.view_init(elev=18, azim=-45)  # Adjust the elevation and azimuthal angles as needed

# Create a smaller inset for the 2D scatter plot
inset_ax = inset_axes(ax, width="25%", height="22%", loc='upper left', borderpad=3)
sc_inset = inset_ax.scatter(predictor_1, predictor_2, c=weights, cmap='Oranges', alpha = 0.8, s=20)
inset_ax.set_xlabel('Predictor 1', fontsize=8)
inset_ax.set_ylabel('Predictor 2', fontsize=8)
inset_ax.tick_params(axis='both', which='major', labelsize=8)
inset_ax.grid(True)

# Colorbar for weights 
cax = inset_axes(inset_ax, width="5%", height="90%", loc='right', borderpad=-1)
cbar_inset = fig.colorbar(sc_inset, cax=cax, orientation='vertical')
cbar_inset.set_label('Weight', fontsize=8)
cbar_inset.ax.tick_params(labelsize=8)

plt.show()

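[Figure: 3D histogram of the two predictors with the interpolated density surface; upper-left inset shows the samples colored by weight]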

The upper left panel shows the two predictors plotted against each other, colored by their respective weight. The 3D histogram shows the binned data, and the blue elevated surface shows the interpolated density. Some points have rather counterintuitive weights: for example, the point near (1.8, 1.9) has a very high weight (so it supposedly lies in a region of low density), even though it arguably lies in a denser region than the sample near (0.4, 3.98). Why is that?

Here is a script for an interactive 3D plot in Python using plotly (it can be run independently of the code above):

interactive 3D plot with python plotly
import numpy as np
import plotly.graph_objects as go

# Simulate some 2D predictor data (two predictors)
np.random.seed(42)
n_samples = 500
predictor_1 = np.random.normal(0, 1, n_samples)
predictor_2 = np.random.normal(0, 1, n_samples)

# Combine into a 2D array
data = np.vstack([predictor_1, predictor_2]).T

# Define bins
n_bins_density = 40
mn, mx = np.nanmin(data, axis=0), np.nanmax(data, axis=0)
bins = np.linspace((mn - 0.05 * (mx - mn)), (mx + 0.05 * (mx - mn)), n_bins_density)

# Create a multidimensional histogram
gmt_hist, edges = np.histogramdd(sample=data, bins=bins.T)

# Calculate the bin centers for interpolation
gmt_bins_center = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]

# Bar positions and heights for the histogram ("ij" order, matching gmt_hist.ravel(); not added to the figure below)
x_flat = gmt_bins_center[0].repeat(len(gmt_bins_center[1]))
y_flat = np.tile(gmt_bins_center[1], len(gmt_bins_center[0]))
z_flat = np.zeros_like(x_flat)  # Start the bars at 0 height
height_flat = gmt_hist.ravel()  # Histogram values as height
bar_width_x = np.diff(edges[0])[0] * 0.8  # Width of each bar in x
bar_width_y = np.diff(edges[1])[0] * 0.8  # Width of each bar in y

# Regular grid on which the interpolated surface is evaluated
grid_x, grid_y = np.meshgrid(np.linspace(mn[0], mx[0], 40), np.linspace(mn[1], mx[1], 40))
interp_points = np.vstack([grid_x.ravel(), grid_y.ravel()]).T
from scipy.interpolate import RegularGridInterpolator
interp = RegularGridInterpolator(points=gmt_bins_center, values=gmt_hist)
grid_z = interp(interp_points).reshape(grid_x.shape)

# Calculate weights as the inverse of the interpolated density
weights = 1 / interp(data)


# Surface trace for the interpolated density
wireframe_trace = go.Surface(
    x=grid_x,
    y=grid_y,
    z=grid_z,
    opacity=0.5,
    colorscale='Blues',
    showscale=False,
    name='Interpolated Surface'
)

# Scatter plot trace for the data points
scatter_trace = go.Scatter3d(
    x=predictor_1,
    y=predictor_2,
    z=[0] * len(predictor_1),  # Place points at z=0 for 3D scatter plot
    mode='markers',
    marker=dict(
        size=5,
        color=weights,
        colorscale='tealrose',
        colorbar=dict(title='Weight'),
        opacity=0.8
    ),
    name='Data Points'
)

# Layout for 3D plot
layout = go.Layout(
    title='Interactive 3D Histogram with Interpolation and Data Points',
    scene=dict(
        xaxis_title='Predictor 1',
        yaxis_title='Predictor 2',
        zaxis_title='Density (Histogram & Interpolation)',
    ),
    margin=dict(l=0, r=0, b=0, t=50),
    showlegend=True
)

# Create figure
fig = go.Figure(data=[wireframe_trace, scatter_trace], layout=layout)

# Show the plot
fig.show()

When you take a closer look:
[Figure (3Dweights2): close-up of the interactive 3D plot]

You can see that the red sample, which lies among quite a few other samples, falls where the interpolated surface is close to zero, likely because it sits at the edge of its bin. So the weights of the samples are probably quite sensitive to the choice of bins. Here is a plot of the same example as in the first figure, just with 25 instead of the default 40 bins:

[Figure: same example as the first figure, with 25 bins]

Here the weights look more intuitive.
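One way to quantify this sensitivity is to recompute the weights for several bin settings and compare them with the default. The sketch below wraps the binning from the snippets above in a helper (hist_weights is a hypothetical name) and reports the spread of the weights and their Spearman rank correlation with the 40-edge weights. Note that n_bins_density is passed to np.linspace, so it is the number of bin edges: 40 edges give 39 bins. Very coarse settings (fewer than 12 edges) are avoided here, because the 5% margin is then smaller than half a bin, so the outermost samples lie beyond the outermost bin centers and the interpolator extrapolates.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import spearmanr


def hist_weights(data, n_bins_density):
    """Inverse of the linearly interpolated histogram count at each sample.

    Hypothetical helper mirroring the snippets in this thread;
    n_bins_density is the number of bin edges per dimension.
    """
    mn, mx = np.nanmin(data, axis=0), np.nanmax(data, axis=0)
    bins = np.linspace(mn - 0.05 * (mx - mn), mx + 0.05 * (mx - mn), n_bins_density)
    hist, edges = np.histogramdd(sample=data, bins=bins.T)
    centers = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]
    interp = RegularGridInterpolator(
        points=centers, values=hist, method="linear", bounds_error=False, fill_value=None
    )
    return 1 / interp(data)


rng = np.random.default_rng(42)
data = rng.normal(0, 1, size=(200, 2))  # two independent standard-normal predictors

reference = hist_weights(data, 40)
for n_edges in (15, 25, 40, 60):
    w = hist_weights(data, n_edges)
    rho, _ = spearmanr(reference, w)  # rank agreement with the default setting
    print(f"{n_edges:2d} edges: max/min weight = {w.max() / w.min():5.1f}, "
          f"Spearman rho vs 40 edges = {rho:.2f}")
```

A low rank correlation for a modest change in the number of edges would indicate that the ordering of the weights, not just their scale, depends on the bin choice.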

Another issue is that grouping the data into rectangular bins can also lead to unintuitive weights. See this example with linearly dependent predictors (which we don't expect, but still):
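[Figure: weights for the linearly dependent predictors tas and tas2, with 3D histogram and interpolated surface]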

Code
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import RegularGridInterpolator
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

n = 40
data_pred = {"tas": np.arange(n),
             "tas2": np.append(np.linspace(0, n/2, n-1), 40),}
n_bins_density = 10

# explode data_pred dictionary into a single array for all predictors
tmp = np.array(list(data_pred.values())).T

# assessing limits on each axis
# TODO *nan*min/max should not be necessary bc we already checked for nan values in the data?
mn, mx = np.nanmin(tmp, axis=0), np.nanmax(tmp, axis=0)

bins = np.linspace(
    (mn - 0.05 * (mx - mn)),
    (mx + 0.05 * (mx - mn)),
    n_bins_density,
)

# interpolating over whole region
gmt_hist, edges = np.histogramdd(sample=tmp, bins=bins.T)

gmt_bins_center = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]
interp = RegularGridInterpolator(points=gmt_bins_center, values=gmt_hist, bounds_error=False, fill_value=None)
weights = 1 / interp(tmp)  # inverse of density

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(data_pred["tas"], data_pred["tas2"], np.zeros(n), label='data')
ax.set_zlim(0, 6)

# Compute the 3D histogram
hist, edges = np.histogramdd((data_pred["tas"], data_pred["tas2"]), bins=(10, 10))

# Get the positions of the bars
xpos, ypos = np.meshgrid(edges[0][:-1] + 0.25, edges[1][:-1] + 0.25, indexing="ij")
xpos = xpos.ravel()
ypos = ypos.ravel()
zpos = 0

# Get the size of the bars
dx = dy = 0.5 * np.ones_like(zpos)
dz = hist.ravel()

# Plot the 3D histogram bars
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, zsort='average', alpha=0.5, color='b', label='3D histogram')


# Create a grid for the interpolation
grid_x, grid_y = np.meshgrid(np.linspace(mn[0], mx[0], n_bins_density), np.linspace(mn[1], mx[1], n_bins_density))
interp_points = np.vstack([grid_x.ravel(), grid_y.ravel()]).T
grid_z = interp(interp_points).reshape(grid_x.shape)

# Plot the wireframe of the interpolation
ax.plot_surface(grid_x, grid_y, grid_z, color='cornflowerblue', alpha=0.7, linewidth=0.7, label='Interpolation')


# Create a smaller inset for the 2D scatter plot
inset_ax = inset_axes(ax, width="25%", height="22%", loc='upper left', borderpad=1)
inset_ax.grid(True)
sc_inset = inset_ax.scatter(data_pred["tas"], data_pred["tas2"], c=weights, cmap='Oranges', s=20)
inset_ax.set_xlabel('tas', fontsize=6)
inset_ax.set_ylabel('tas2', fontsize=6)
inset_ax.tick_params(axis='both', which='major', labelsize=8)

# Colorbar for weights 
cax = inset_axes(inset_ax, width="5%", height="90%", loc='right', borderpad=-1)
cbar_inset = fig.colorbar(sc_inset, cax=cax, orientation='vertical')
cbar_inset.set_label('Weight', fontsize=8)
cbar_inset.ax.tick_params(labelsize=8)

plt.show()
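The rectangular-bin issue can be made concrete by counting how many bins are actually occupied. This sketch recreates the data_pred example above (same tas, tas2 and n_bins_density = 10, i.e. 9 × 9 bins) and counts the non-empty bins. Because the samples lie close to a line, only a thin diagonal band of the axis-aligned grid contains data, so the interpolated density at a sample is largely determined by how that line crosses the bin boundaries and by the many empty neighboring bins.

```python
import numpy as np

# Same nearly collinear predictors as in the example above
n = 40
tas = np.arange(n)
tas2 = np.append(np.linspace(0, n / 2, n - 1), 40)
data = np.column_stack([tas, tas2])

n_bins_density = 10  # number of bin edges per axis -> 9 x 9 bins
mn, mx = data.min(axis=0), data.max(axis=0)
bins = np.linspace(mn - 0.05 * (mx - mn), mx + 0.05 * (mx - mn), n_bins_density)
hist, _ = np.histogramdd(data, bins=bins.T)

occupied = np.count_nonzero(hist)
print(f"{occupied} of {hist.size} rectangular bins contain samples")
print("samples per occupied bin:", hist[hist > 0].astype(int))
```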

All in all, I think for now we should 1) factor this function out of the init of the distribution class and give the user more control over the choice of bins, and 2) add a disclaimer to the function's docstring that the weights may be sensitive to the bin choice, along with a plotting example so users can easily inspect the weights as I did above. However, for more than two predictors it would be quite hard to make a nice visualization...
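Point 1) could look roughly like the following sketch: a standalone function that exposes the number of bin edges and the margin as arguments, plus a small helper to inspect the weights for two predictors. Both weights_from_histogram and plot_weights are hypothetical names, not existing functions in this project, and the sketch reproduces the histogram-plus-interpolation logic from the snippets above rather than the actual distrib_cov code.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator


def weights_from_histogram(predictors, n_bin_edges=40, margin=0.05):
    """Hypothetical standalone version of the histogram-based sample weights.

    predictors : dict mapping predictor names to 1D arrays of equal length
    n_bin_edges : number of bin edges per predictor (gives n_bin_edges - 1 bins)
    margin : fraction of each predictor's range added below and above the data

    Returns 1 / (linearly interpolated histogram count) at every sample.
    """
    data = np.column_stack([np.asarray(v, dtype=float) for v in predictors.values()])
    mn, mx = data.min(axis=0), data.max(axis=0)
    pad = margin * (mx - mn)
    bins = np.linspace(mn - pad, mx + pad, n_bin_edges)  # shape (n_bin_edges, n_predictors)
    hist, edges = np.histogramdd(data, bins=bins.T)
    centers = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]
    interp = RegularGridInterpolator(
        points=centers, values=hist, method="linear", bounds_error=False, fill_value=None
    )
    return 1 / interp(data)


def plot_weights(predictors, weights, ax=None):
    """Scatter the first two predictors against each other, colored by weight."""
    import matplotlib.pyplot as plt  # only needed for plotting

    (name_x, x), (name_y, y) = list(predictors.items())[:2]
    if ax is None:
        _, ax = plt.subplots()
    sc = ax.scatter(x, y, c=weights, cmap="Oranges", s=20)
    ax.figure.colorbar(sc, ax=ax, label="Weight")
    ax.set_xlabel(name_x)
    ax.set_ylabel(name_y)
    return ax


rng = np.random.default_rng(0)
predictors = {"predictor_1": rng.normal(size=200), "predictor_2": rng.normal(size=200)}
weights = weights_from_histogram(predictors, n_bin_edges=25)
```

Calling plot_weights(predictors, weights) followed by matplotlib.pyplot.show() gives the inset-style scatter plot from above as a standalone figure, which makes it easy to compare several n_bin_edges values side by side.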

@mathause
Member

mathause commented Oct 7, 2024

Thanks for looking into that. Random unsolicited thought: could we infer the density using a KDE (kernel density estimate)? I.e. each data point contributes to the density at a target point with a weight that decreases with its distance from the target (using a Gaussian or Gaspari-Cohn kernel). There is probably a nice package for that.
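To make the KDE idea concrete, here is a minimal NumPy sketch with an isotropic Gaussian kernel (gaussian_kde_weights is a hypothetical name, and the Gaspari-Cohn kernel is not implemented here). The density at each sample is the average of normalized Gaussian kernels centered on all samples, so nearby samples contribute much more than distant ones, and the weight is the inverse of that density. Building the full n × n distance matrix is only reasonable for up to a few thousand samples.

```python
import numpy as np


def gaussian_kde_weights(data, bandwidth):
    """Inverse Gaussian kernel density estimate evaluated at each sample.

    data : array of shape (n_samples, n_predictors)
    bandwidth : standard deviation of the isotropic Gaussian kernel
    """
    data = np.asarray(data, dtype=float)
    n_dims = data.shape[1]
    # squared distances between all pairs of samples, shape (n_samples, n_samples)
    diff = data[:, None, :] - data[None, :, :]
    sq_dist = np.sum(diff**2, axis=-1)
    # normalized Gaussian kernel: each kernel integrates to 1 over predictor space
    norm = (2 * np.pi * bandwidth**2) ** (n_dims / 2)
    kernel = np.exp(-0.5 * sq_dist / bandwidth**2) / norm
    density = kernel.mean(axis=1)  # average contribution of all samples
    return 1 / density


rng = np.random.default_rng(42)
data = rng.normal(0, 1, size=(200, 2))
weights = gaussian_kde_weights(data, bandwidth=0.5)
```

With a Gaussian kernel and bandwidth=0.5 this should match the inverse of exp(kde.score_samples(...)) from scikit-learn's KernelDensity in the next comment, since that estimator also averages normalized Gaussian kernels.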

@veni-vidi-vici-dormivi
Collaborator Author

Yes indeed, we could:

import numpy as np
from sklearn.neighbors import KernelDensity
import matplotlib.pyplot as plt

np.random.seed(42)
n_samples = 200
predictor_1 = np.random.normal(0, 1, n_samples)
predictor_2 = np.random.normal(0, 1, n_samples)

kde = KernelDensity(bandwidth=0.5, kernel='gaussian')
kde.fit(np.vstack([predictor_1, predictor_2]).T)

# evaluate density on samples
density = np.exp(kde.score_samples(np.vstack([predictor_1, predictor_2]).T))

# Plot
plt.scatter(predictor_1, predictor_2, c=density, cmap='viridis')
plt.colorbar(label='Density')
plt.xlabel('Predictor 1')
plt.ylabel('Predictor 2')
plt.title('2D KDE Density Estimation')
plt.show()

[Figure: 2D KDE density estimation, samples colored by density]

However, it seems to be quite a bit slower?

import numpy as np
from scipy.interpolate import RegularGridInterpolator
from sklearn.neighbors import KernelDensity

np.random.seed(42)
n_samples = 200
predictor_1 = np.random.normal(0, 1, n_samples)
predictor_2 = np.random.normal(0, 1, n_samples)

def kde(predictor_1, predictor_2):
    data = np.vstack([predictor_1, predictor_2]).T

    kde_model = KernelDensity(bandwidth=0.5, kernel='gaussian')
    kde_model.fit(data)

    # evaluate density on samples (score_samples returns the log-density)
    density = np.exp(kde_model.score_samples(data))
    weights = 1 / density
    return weights

def hist(predictor_1, predictor_2):
    # Combine into a 2D array
    data = np.vstack([predictor_1, predictor_2]).T

    # Define bins (like in the original function)
    n_bins_density = 25 # 25 # 40
    mn, mx = np.nanmin(data, axis=0), np.nanmax(data, axis=0)
    bins = np.linspace(
        (mn - 0.05 * (mx - mn)),
        (mx + 0.05 * (mx - mn)),
        n_bins_density,
    )

    # Create a multidimensional histogram
    gmt_hist, edges = np.histogramdd(sample=data, bins=bins.T)

    # Calculate the bin centers for interpolation
    gmt_bins_center = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]
    interp = RegularGridInterpolator(points=gmt_bins_center, values=gmt_hist, method = "linear", bounds_error=False, fill_value=None)

    # Calculate weights as the inverse of the interpolated density
    weights = 1 / interp(data)
    return weights

# %timeit is an IPython/Jupyter magic
%timeit kde(predictor_1, predictor_2)
# > 1.37 ms ± 19.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit hist(predictor_1, predictor_2)
# > 72.7 µs ± 743 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

@mathause
Member

mathause commented Oct 7, 2024

I think that would be ok if we do it only once per gridpoint (or maybe even once per model). One potential issue: the scale of the data may not be the same in each dimension, so we may have to scale the data or the bandwidth.

But anyway, factoring this out is a good idea in any case.

@veni-vidi-vici-dormivi
Collaborator Author

I think once per model should be enough, because the predictors are the same for each gridpoint.
Do you mean scale in the sense of the variance of the distribution, or the number of points?
Yes, factoring it out is definitely the first step.

@mathause
Member

mathause commented Oct 7, 2024

I think in terms of variance; e.g. if we have one predictor that is temperature in K and one that is precip in mm / s or whatnot, which have very different data ranges. (But maybe that's already in there?)
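The scale issue can be illustrated with two made-up predictors on very different scales: a temperature-like one with a spread of a few K and a precipitation-like one with a spread of order 1e-6. With an isotropic kernel applied to the raw values, the second predictor is effectively ignored, because its differences are tiny compared with the bandwidth. The sketch shows two possible remedies, assuming that rescaling each predictor is acceptable: (1) standardize each predictor to unit variance before applying an isotropic kernel, or (2) use scipy.stats.gaussian_kde, whose kernel covariance is the data covariance multiplied by a squared bandwidth factor (Scott's rule by default), so it adapts to the spread of each predictor. isotropic_kde_density is a hypothetical helper.

```python
import numpy as np
from scipy.stats import gaussian_kde


def isotropic_kde_density(data, bandwidth):
    """Gaussian KDE with the same bandwidth in every dimension, evaluated at the samples."""
    n_dims = data.shape[1]
    sq_dist = np.sum((data[:, None, :] - data[None, :, :]) ** 2, axis=-1)
    norm = (2 * np.pi * bandwidth**2) ** (n_dims / 2)
    return np.exp(-0.5 * sq_dist / bandwidth**2).mean(axis=1) / norm


rng = np.random.default_rng(0)
n = 200
tas = rng.normal(288.0, 5.0, n)  # made-up temperature-like predictor [K]
pr = rng.normal(3e-5, 5e-6, n)   # made-up precipitation-like predictor [kg m-2 s-1]
data = np.column_stack([tas, pr])

# Raw data: pr differences (~1e-6) vanish next to the bandwidth, so only tas matters
w_raw = 1 / isotropic_kde_density(data, bandwidth=0.5)
w_tas_only = 1 / isotropic_kde_density(data[:, :1], bandwidth=0.5)

# Remedy 1: standardize each predictor, then use an isotropic kernel
standardized = (data - data.mean(axis=0)) / data.std(axis=0)
w_std = 1 / isotropic_kde_density(standardized, bandwidth=0.5)

# Remedy 2: kernel covariance follows the data covariance (dataset shape: (n_dims, n_samples))
kde = gaussian_kde(data.T)
w_cov = 1 / kde(data.T)

# Compare relative weights (each set normalized to mean 1)
for name, w in [("raw", w_raw), ("tas only", w_tas_only),
                ("standardized", w_std), ("gaussian_kde", w_cov)]:
    w = w / w.mean()
    print(f"{name:>12}: min {w.min():.2f}, max {w.max():.2f}")
```

The "raw" and "tas only" weights agree up to a constant factor, which is exactly the problem: without rescaling, the precipitation-like predictor has no influence on the weights.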

@veni-vidi-vici-dormivi
Collaborator Author

veni-vidi-vici-dormivi commented Oct 7, 2024

Also made a visual for the 1D case:
pred = np.array([0,1,1,2,3,3,4,4,4,4])
[Figure: 1D case, histogram with linearly interpolated density and KDE density at the samples]

Code
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import RegularGridInterpolator
from sklearn.neighbors import KernelDensity

pred = np.array([0,1,1,2,3,3,4,4,4,4])
n_bins_density = 5

# histogram density over the whole range (density=True normalizes the counts)
gmt_hist, edges = np.histogramdd(sample=pred, bins=n_bins_density, density=True)

gmt_bins_center = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]

# linear interpolation between bin centers, extrapolated beyond the outermost centers
interp = RegularGridInterpolator(points=gmt_bins_center, values=gmt_hist, method="linear", bounds_error=False, fill_value=None)
density_hist = interp(np.unique(pred))

# Gaussian KDE with Scott's rule for the bandwidth
kde = KernelDensity(bandwidth="scott", kernel='gaussian')
kde.fit(pred.reshape(-1, 1))
density_kde = np.exp(kde.score_samples(np.unique(pred).reshape(-1, 1)))

plt.hist(pred, bins=n_bins_density, density=True)
plt.plot(gmt_bins_center[0], gmt_hist, color = "orange", label = "linear interpolation")
plt.scatter(np.unique(pred), density_hist, color = "red", marker = "x", label = "interpolated density at samples")
plt.scatter(np.unique(pred), density_kde, color = "pink", marker = "x", label = "KDE density at samples")
plt.scatter(edges[0], np.zeros(len(edges[0])), color = "black", marker = "|")  # bin edges
plt.legend()
plt.show()
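The numbers behind this 1D figure can be checked by hand: with bins=5 the edges are 0, 0.8, ..., 4.0, the bin counts are 1, 2, 1, 2, 4, and density=True divides them by 10 samples × 0.8 bin width, giving 0.125, 0.25, 0.125, 0.25, 0.5 at the bin centers 0.4, 1.2, 2.0, 2.8, 3.6. The samples at 0 and 4 lie outside the outermost centers, so with fill_value=None the interpolator extrapolates linearly there. The sketch below is a trimmed version of the snippet above (without the KDE part and the plot) that reproduces these values.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

pred = np.array([0, 1, 1, 2, 3, 3, 4, 4, 4, 4])

# 5 equal-width bins over [0, 4]; density=True divides counts by (10 samples * 0.8)
hist, edges = np.histogramdd(sample=pred, bins=5, density=True)
centers = [0.5 * (edge[1:] + edge[:-1]) for edge in edges]

interp = RegularGridInterpolator(
    points=centers, values=hist, method="linear", bounds_error=False, fill_value=None
)
samples = np.unique(pred)
density_at_samples = interp(samples.reshape(-1, 1))

print("bin centers:", centers[0])
print("histogram density:", hist)
print("interpolated density at", samples, ":", density_at_samples)
```

The interpolated densities at the samples 0, 1, 2, 3, 4 are 0.0625, 0.21875, 0.125, 0.3125 and 0.625: the lone sample at 0 gets only half of its bin's density, the sample at 2 keeps its bin value because it sits exactly on a bin center, and the samples at 4 get more than their bin's 0.5 because of the extrapolation.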
