[WIP] Multivariate Info Theory #109

Merged

255 changes: 247 additions & 8 deletions miv/statistics/info_theory.py
@@ -1,12 +1,16 @@
__all__ = [
"probability_distribution",
"shannon_entropy",
"block_entropy",
"entropy_rate",
"active_information",
"mutual_information",
"joint_entropy",
"relative_entropy",
"conditional_entropy",
"cross_entropy",
"transfer_entropy",
"pid",
]


@@ -27,6 +31,50 @@
from miv.typing import SpikestampsType


def probability_distribution(
    spiketrains: SpikestampsType,
    channel: float,
    t_start: float,
    t_end: float,
    bin_size: float,
):
    """
    Forms the probability distribution required to compute the information theory measures.
    Probabilities are computed from the binned spiketrain generated for the specified bin size.

    Parameters
    ----------
    spiketrains : SpikestampsType
        Single spike-stamps
    channel : float
        electrode/channel
    t_start : float
        Binning start time
    t_end : float
        Binning end time
    bin_size : float
        bin size in seconds

    Returns
    -------
    probability_distribution: np.ndarray
        probability distribution for the provided spiketrain

    """
    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"

    bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
    prob_spike = np.sum(bin_spike) / np.size(bin_spike)
    prob_no_spike = 1 - prob_spike
    # build the per-bin probability array as floats, so the probabilities are
    # not truncated when the binned spiketrain has an integer dtype
    probability_distribution = np.where(bin_spike == 1, prob_spike, prob_no_spike)

    return probability_distribution
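
# A minimal doctest-style sketch of the mapping above, using a synthetic
# binary array in place of a real binned spiketrain:
#
#   >>> import numpy as np
#   >>> bin_spike = np.array([1, 0, 0, 1, 0])
#   >>> p = np.sum(bin_spike) / np.size(bin_spike)  # 0.4
#   >>> np.where(bin_spike == 1, p, 1 - p)
#   array([0.4, 0.6, 0.6, 0.4, 0.6])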


def shannon_entropy(
    spiketrains: SpikestampsType,
    channel: float,
@@ -35,7 +83,7 @@ def shannon_entropy(
    bin_size: float,
):
"""
Estimates the shannon entropy for a single channel recording using the binned spiketrain
Estimates the shannon entropy for a single channel recording using the binned spiketrain.

    Parameters
    ----------
@@ -56,10 +104,14 @@ def shannon_entropy(
        Shannon entropy for the given channel

    """

    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"
    bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
    prob_spike = np.sum(bin_spike) / np.size(bin_spike)
    prob_no_spike = 1 - prob_spike
    # a deterministic train (all spikes or no spikes) carries no uncertainty;
    # returning 0 early also avoids evaluating log2(0) below
    if prob_spike == 0 or prob_no_spike == 0:
        return 0.0
    shannon_entropy = -(
        prob_spike * np.log2(prob_spike) + prob_no_spike * np.log2(prob_no_spike)
    )
    return shannon_entropy
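
# Worked example of the formula above: for spike probability p = 0.25,
#   H = -(0.25 * log2(0.25) + 0.75 * log2(0.75)) ≈ 0.811 bits.
# A hedged usage sketch (spiketrains and the time values assumed loaded elsewhere):
#   h = shannon_entropy(spiketrains, channel=0, t_start=0.0, t_end=10.0, bin_size=0.002)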


@@ -95,7 +147,9 @@ def block_entropy(
        Block entropy for the given channel

    """

    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"
    assert his > 0, "history length should be a finite positive value"
    bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
    block_entropy = pyinform.blockentropy.block_entropy(bin_spike, his)
    return block_entropy
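
# Block entropy generalizes the single-bin entropy above to length-`his`
# words: his=2, for example, measures the uncertainty over the four patterns
# 00, 01, 10, 11 in the binned series. A hedged usage sketch:
#   h2 = block_entropy(spiketrains, channel=0, t_start=0.0, t_end=10.0, bin_size=0.002, his=2)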
@@ -133,7 +187,9 @@ def entropy_rate(
        entropy rate for the given channel

    """

    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"
    assert his > 0, "history length should be a finite positive value"
    bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
    entropy_rate = pyinform.entropyrate.entropy_rate(bin_spike, k=his)
    return entropy_rate
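
# For history length k, the entropy rate is the block-entropy increment
#   h(k) = H(k + 1) - H(k),
# i.e. the average uncertainty of the next bin given the previous k bins.
# A hedged usage sketch:
#   h = entropy_rate(spiketrains, channel=0, t_start=0.0, t_end=10.0, bin_size=0.002, his=2)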
@@ -171,7 +227,9 @@ def active_information(
        Active information for the given channel

    """

    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"
    assert his > 0, "history length should be a finite positive value"
    bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
    active_information = pyinform.activeinfo.active_info(bin_spike, his)
    return active_information
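
# Active information is the mutual information between the k-bin history and
# the next bin, A(k) = H(X_next) - H(X_next | X_history). A hedged sketch:
#   ai = active_information(spiketrains, channel=0, t_start=0.0, t_end=10.0, bin_size=0.002, his=2)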
@@ -209,13 +267,60 @@ def mutual_information(
        Mutual information for the given pair of electrodes

    """

    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"
    bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
    bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
    mutual_information = pyinform.mutualinfo.mutual_info(bin_spike_x, bin_spike_y)
    return mutual_information
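
# Mutual information relates to the entropy estimators above through
#   I(X; Y) = H(X) + H(Y) - H(X, Y),
# which offers a quick consistency check (up to estimator differences):
#   mi = mutual_information(spiketrains, channelx=0, channely=1, t_start=0.0, t_end=10.0, bin_size=0.002)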


def joint_entropy(
    spiketrains: SpikestampsType,
    channelx: float,
    channely: float,
    t_start: float,
    t_end: float,
    bin_size: float,
):
    """
    Estimates the joint entropy for the pair of electrode recordings (X & Y) using the binned spiketrains.

    Parameters
    ----------
    spiketrains : SpikestampsType
        Single spike-stamps
    channelx : float
        electrode/channel X
    channely : float
        electrode/channel Y
    t_start : float
        Binning start time
    t_end : float
        Binning end time
    bin_size : float
        bin size in seconds

    Returns
    -------
    joint_entropy: float
        joint entropy for the given pair of electrodes

    """
    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"

    bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
    bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
    # estimate the joint distribution over the four (x, y) outcomes
    joint_prob = np.zeros((2, 2))
    for x, y in zip(bin_spike_x, bin_spike_y):
        joint_prob[int(x), int(y)] += 1
    joint_prob /= np.sum(joint_prob)
    # sum only over outcomes with non-zero probability to avoid log2(0)
    nonzero = joint_prob > 0
    joint_entropy = -np.sum(joint_prob[nonzero] * np.log2(joint_prob[nonzero]))
    return joint_entropy
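
# Worked toy example of the estimator above: for the binned series
#   x = [0, 1, 0, 1] and y = [0, 1, 1, 0],
# each of the four (x, y) outcomes occurs once, so
#   H(X, Y) = -4 * (1/4) * log2(1/4) = 2 bits.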


def relative_entropy(
    spiketrains: SpikestampsType,
    channelx: float,
@@ -248,6 +353,8 @@ def relative_entropy(
        Relative entropy for the given pair of electrodes

    """
    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"

    bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
    bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
@@ -289,6 +396,8 @@ def conditional_entropy(
        conditional entropy for the given pair of electrodes

    """
    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"

    bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
    bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
@@ -298,6 +407,51 @@
    return conditional_entropy
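
# Conditional entropy relates to the estimators above through
#   H(Y | X) = H(X, Y) - H(X);
# with the toy series from the joint_entropy example, H(Y | X) = 2 - 1 = 1 bit.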


def cross_entropy(
    spiketrains: SpikestampsType,
    channelx: float,
    channely: float,
    t_start: float,
    t_end: float,
    bin_size: float,
):
    """
    Estimates the cross entropy for the pair of electrode recordings (X & Y) using the binned spiketrains.

    Parameters
    ----------
    spiketrains : SpikestampsType
        Single spike-stamps
    channelx : float
        electrode/channel X
    channely : float
        electrode/channel Y
    t_start : float
        Binning start time
    t_end : float
        Binning end time
    bin_size : float
        bin size in seconds

    Returns
    -------
    cross_entropy: float
        cross entropy for the given pair of electrodes

    """
    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"

    bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
    bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
    # spike probabilities of the two binary outcome distributions
    prob_spike_x = np.sum(bin_spike_x) / np.size(bin_spike_x)
    prob_spike_y = np.sum(bin_spike_y) / np.size(bin_spike_y)
    # H(p, q) = -sum_s p(s) log2 q(s) over the two outcomes (spike, no spike)
    cross_entropy = -(
        prob_spike_x * np.log2(prob_spike_y)
        + (1 - prob_spike_x) * np.log2(1 - prob_spike_y)
    )
    return cross_entropy
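
# Worked example of the formula above: with p_x = 0.5 and p_y = 0.25,
#   H(p, q) = -(0.5 * log2(0.25) + 0.5 * log2(0.75)) ≈ 1.21 bits,
# exceeding H(p) = 1 bit by exactly the relative entropy D_KL(p || q).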


def transfer_entropy(
    spiketrains: SpikestampsType,
    channelx: float,
@@ -333,10 +487,95 @@ def transfer_entropy(
        Transfer entropy for the given pair of electrodes

    """
    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"
    assert his > 0, "history length should be a finite positive value"

    bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
    bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
    transfer_entropy = pyinform.transferentropy.transfer_entropy(
        bin_spike_x, bin_spike_y, k=his
    )
    return transfer_entropy
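
# Transfer entropy is directional, so TE(X -> Y) and TE(Y -> X) generally
# differ. A hedged usage sketch (data assumed loaded elsewhere):
#   te_xy = transfer_entropy(spiketrains, channelx=0, channely=1, t_start=0.0, t_end=10.0, bin_size=0.002, his=2)
#   te_yx = transfer_entropy(spiketrains, channelx=1, channely=0, t_start=0.0, t_end=10.0, bin_size=0.002, his=2)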


def pid(
    spiketrains: SpikestampsType,
    channelx: float,
    channely: float,
    channelz: float,
    t_start: float,
    t_end: float,
    bin_size: float,
):
    """
    Decomposes the information provided by channels x and y about channel z into redundancy, unique information, and synergy.

    Parameters
    ----------
    spiketrains : SpikestampsType
        Single spike-stamps
    channelx : float
        electrode/channel X
    channely : float
        electrode/channel Y
    channelz : float
        electrode/channel Z
    t_start : float
        Binning start time
    t_end : float
        Binning end time
    bin_size : float
        bin size in seconds

    Returns
    -------
    redundancy: float
        redundant information provided by both x and y about z
    unique_information_x: float
        information uniquely provided by x about z
    unique_information_y: float
        information uniquely provided by y about z
    synergy: float
        synergistic information provided by both x and y about z
    total_information: float
        total information provided by both x and y about z
    """
    assert t_start < t_end, "start time must be earlier than end time"
    assert bin_size > 0, "bin_size should be a finite positive value"

    bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
    bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
    bin_spike_z = binned_spiketrain(spiketrains, channelz, t_start, t_end, bin_size)

    # pairwise mutual informations with the target z
    I_x_z = pyinform.mutualinfo.mutual_info(bin_spike_x, bin_spike_z)
    I_y_z = pyinform.mutualinfo.mutual_info(bin_spike_y, bin_spike_z)
    # encode the joint (x, y) source state as a single four-state series so
    # the total information I(X, Y; Z) is a single mutual information
    bin_spike_xy = 2 * np.asarray(bin_spike_x) + np.asarray(bin_spike_y)
    total_information = pyinform.mutualinfo.mutual_info(bin_spike_xy, bin_spike_z)
    # redundancy as the minimum of the two pairwise mutual informations
    redundancy = min(I_x_z, I_y_z)
    unique_information_x = I_x_z - redundancy
    unique_information_y = I_y_z - redundancy
    synergy = (
        total_information - redundancy - unique_information_x - unique_information_y
    )
    return (
        redundancy,
        unique_information_x,
        unique_information_y,
        synergy,
        total_information,
    )
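
# By construction the decomposition satisfies
#   total_information = redundancy + unique_x + unique_y + synergy.
# A hedged usage sketch (data assumed loaded elsewhere):
#   red, uniq_x, uniq_y, syn, total = pid(spiketrains, 0, 1, 2, t_start=0.0, t_end=10.0, bin_size=0.002)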