-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add em and draft nll #29
Open
PBrblt
wants to merge
27
commits into
braindatalab:master
Choose a base branch
from
PBrblt:em_estimator
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
4c25b67
add em and draft nll
PBrblt 330a1ae
correction_of_em
PBrblt affed29
include em and nll to benchmark
PBrblt 5c909b5
Merge branch 'master' into em_estimator
anujanegi 33d7da0
correcting input parameters
PBrblt 0e5b13f
Merge branch 'master' into em_estimator
anujanegi 8291250
correcting the update
PBrblt 0d2af84
correcting NLL
PBrblt ca2b9df
fix nll + add it to benchmark
agramfort d79a01b
Merge branch 'em_estimator' of https://github.com/PBrblt/BSI-Zoo into…
PBrblt f9b0208
Merge branch 'master' into em_estimator
anujanegi b45946c
Merge branch 'master' into em_estimator
anujanegi 9fb10ad
add positivity check
PBrblt 289954b
further security in moments
PBrblt b489fc3
Update estimators.py
PBrblt 3ed19e5
Update estimators.py
PBrblt 00f8fe9
Update estimators.py
PBrblt 10ea458
making the moments method more robust
PBrblt f779512
Update estimators.py
PBrblt 81f2a1a
tuning number of iteration
PBrblt e674e99
Update estimators.py
PBrblt 6f983db
tuning trust
PBrblt 90a24b0
Update estimators.py
PBrblt e2ebe63
Update estimators.py
PBrblt 5b8128a
removing the -0
PBrblt f6aaa14
Update estimators.py
PBrblt 180b401
Update estimators.py
PBrblt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -671,3 +671,99 @@ def champagne(L, y, cov=1.0, alpha=0.2, max_iter=1000, max_iter_reweighting=10): | |
x[active_set, :] = x_bar | ||
|
||
return x | ||
|
||
|
||
def lemur(L, y, alpha=0.2, max_iter=1000, max_iter_em=100, trust_tresh=0.9): | ||
"""Latent EM Unsupervised Regression based on https://ieeexplore.ieee.org/document/9746697 | ||
|
||
Parameters | ||
---------- | ||
L : array, shape (n_sensors, n_sources) | ||
lead field matrix modeling the forward operator or dictionary matrix | ||
y : array, shape (n_sensors,) | ||
measurement vector, capturing sensor measurements | ||
max_iter : int, optional | ||
The maximum number of inner loop iterations | ||
max_iter_reweighting : int, optional | ||
Maximum number of reweighting steps i.e outer loop iterations | ||
|
||
Returns | ||
------- | ||
x : array, shape (n_sources,) | ||
Parameter vector, e.g., source vector in the context of BSI (x in the cost | ||
function formula). | ||
|
||
References | ||
---------- | ||
XXX | ||
""" | ||
n_sensors, n_sources = L.shape | ||
_, n_times = y.shape | ||
|
||
def perform_moments(mixture): | ||
"""Moments identification method for gaussian mixture.""" | ||
|
||
m2, m4, m6 = np.mean(mixture ** 2), np.mean(mixture ** 4) / 3, np.mean(mixture ** 6) / 15 | ||
|
||
a = m2 ** 2 - m4 | ||
b = m6 - m2 * m4 | ||
c = m4 ** 2 - m2 * m6 | ||
|
||
if a > 0 : | ||
tresh = m2-np.sqrt(a) | ||
else : | ||
tresh = m2 | ||
|
||
disc = b ** 2 - 4 * a * c | ||
if disc<0 : | ||
#print("oops") | ||
disc = 0 | ||
|
||
# there are two roots for sigma_b_2, however the good one must be in the interval [0,m2-sqrt(a)] | ||
|
||
if ( - b /(2 * a) - np.sqrt(disc)/(2 * a) )>=0 and ( - b /(2 * a) - np.sqrt(disc)/(2 * a) )<= tresh : | ||
sigma_b_2 = - b /(2 * a) - np.sqrt(disc)/(2 * a) | ||
elif ( - b /(2 * a) + np.sqrt(disc)/(2 * a) )>=0 and ( - b /(2 * a) + np.sqrt(disc)/(2 * a) )<= tresh : | ||
sigma_b_2 = - b /(2 * a) + np.sqrt(disc)/(2 * a) | ||
else : | ||
sigma_b_2 = 0.99*tresh # worst case scenario | ||
#sigma_b_2 = - b /(2 * a) + max( - np.sqrt(disc)/(2 * a) , np.sqrt(disc)/(2 * a) ) | ||
sigma_x_2 = (m4 - sigma_b_2 ** 2)/(m2 - sigma_b_2) - 2 * sigma_b_2 | ||
p = (m2 - sigma_b_2)/sigma_x_2 | ||
|
||
return (p, sigma_x_2, sigma_b_2) | ||
|
||
def em_step(obs, param): | ||
"""EM update with x as complete data.""" | ||
|
||
rho = (1-param[0])/param[0] | ||
mu = param[1]/(param[1]+param[2]) | ||
|
||
#s_x = param[1]**2 | ||
#s_b = param[2]**2 | ||
|
||
phi_k = 1/( | ||
1 + rho * np.sqrt(param[1]/param[2] + 1) * np.exp( -np.sum(obs**2,axis=1)/2 *mu/param[2] ) | ||
) | ||
|
||
p = np.mean(phi_k) | ||
s_x = mu*param[2] + mu**2/p * np.mean(phi_k*np.mean(obs**2,axis=1) ) | ||
s_b = np.mean(obs**2) - 2 * mu * np.mean(phi_k*np.mean(obs**2,axis=1)) + p*s_x**2 | ||
|
||
X_eap = obs * phi_k[:,None] * s_x / (s_x + s_b) | ||
|
||
return ([p, s_x, abs(s_b)], X_eap,phi_k)#abs as a sanity guideline | ||
|
||
x = L.T@y # initialisation of X | ||
theta_p = [0,0,0] # initialisation of theta | ||
norm = np.linalg.norm([email protected],2) | ||
|
||
for k_inner in range(max_iter): | ||
z = x + L.T@(y - L@x)/norm# Gradient descent | ||
theta = perform_moments(z)# Initialisation of the EM (feel free to find better ones !) | ||
for k_em in range(max_iter_em): | ||
theta, x, phi = em_step(z, theta)# EM updates | ||
x_res = np.zeros(np.shape(x)) | ||
active = phi>trust_tresh | ||
x_res[active,:] = x[active,:] | ||
return x_res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rtol and atol values are chosen by you @PBrblt