Merge pull request #109 from Gauravu2/Info_theory_enhancement
[WIP] Multivariate Info Theory
skim0119 authored Jan 27, 2023
2 parents 02d8328 + 66366bf commit 9734a8a
Showing 1 changed file with 247 additions and 8 deletions.
255 changes: 247 additions & 8 deletions miv/statistics/info_theory.py
@@ -1,12 +1,16 @@
__all__ = [
"probability_distribution",
"shannon_entropy",
"block_entropy",
"entropy_rate",
"active_information",
"mutual_information",
"joint_entropy",
"relative_entropy",
"conditional_entropy",
"cross_entropy",
"transfer_entropy",
"pid",
]


@@ -27,6 +31,50 @@
from miv.typing import SpikestampsType


def probability_distribution(
spiketrains: SpikestampsType,
channel: float,
t_start: float,
t_end: float,
bin_size: float,
):

"""
Forms the probability distribution required to compute the information-theory measures. Each bin of the binned spiketrain (generated at the specified bin size) is mapped to the empirical probability of its state: spike or no spike.
Parameters
----------
spiketrains : SpikestampsType
Single spike-stamps
channel : float
electrode/channel
t_start : float
Binning start time
t_end : float
Binning end time
bin_size : float
bin size in seconds
Returns
-------
probability_distribution: np.ndarray
probability distribution for the provided spiketrain
"""

assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"

bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
prob_spike = np.sum(bin_spike) / np.size(bin_spike)
prob_no_spike = 1 - prob_spike
# np.where keeps a float result and avoids the in-place pitfalls of the
# masked assignments: integer dtypes truncating the probabilities, and
# values rewritten by the first assignment being re-matched by the second.
probability_distribution = np.where(bin_spike == 1, prob_spike, prob_no_spike)

return probability_distribution
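
# A minimal sketch (standalone numpy, toy numbers) of the bin-to-probability
# mapping above: spiking bins take the value p(spike), silent bins p(no spike).
import numpy as np

toy_bin_spike = np.array([1, 0, 0, 1, 0, 0, 0, 1])
toy_prob_spike = np.sum(toy_bin_spike) / np.size(toy_bin_spike)  # 3/8 = 0.375
toy_dist = np.where(toy_bin_spike == 1, toy_prob_spike, 1 - toy_prob_spike)
print(toy_dist)  # [0.375 0.625 0.625 0.375 0.625 0.625 0.625 0.375]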


def shannon_entropy(
spiketrains: SpikestampsType,
channel: float,
@@ -35,7 +83,7 @@ def shannon_entropy(
bin_size: float,
):
"""
Estimates the Shannon entropy for a single channel recording using the binned spiketrain.
Parameters
----------
Expand All @@ -56,10 +104,14 @@ def shannon_entropy(
Shannon entropy for the given channel
"""

assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"
bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
prob_spike = np.sum(bin_spike) / np.size(bin_spike)
prob_no_spike = 1 - prob_spike
# A channel that never (or always) spikes carries no uncertainty; return 0
# directly instead of evaluating 0 * log2(0).
if prob_spike == 0 or prob_no_spike == 0:
    return 0.0
shannon_entropy = -(
prob_spike * np.log2(prob_spike) + prob_no_spike * np.log2(prob_no_spike)
)
return shannon_entropy
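
# Worked check of the binary-entropy formula above with the toy spike
# probability 0.375; any p in (0, 1) works the same way.
import numpy as np

p = 0.375
h = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
print(round(h, 3))  # 0.954 bits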


@@ -95,7 +147,9 @@ def block_entropy(
Block entropy for the given channel
"""

assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"
assert his > 0, "history length should be a finite positive value"
bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
block_entropy = pyinform.blockentropy.block_entropy(bin_spike, his)
return block_entropy
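
# A toy pyinform sketch of block entropy: with k=2, every length-2 spike
# pattern in the binned train is a block (the series values are made up).
import pyinform

toy_series = [0, 0, 1, 1, 1, 1, 0, 0, 0]
print(round(pyinform.blockentropy.block_entropy(toy_series, k=2), 3))  # 1.811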
@@ -133,7 +187,9 @@ def entropy_rate(
entropy rate for the given channel
"""

assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"
assert his > 0, "history length should be a finite positive value"
bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
entropy_rate = pyinform.entropyrate.entropy_rate(bin_spike, k=his)
return entropy_rate
@@ -171,7 +227,9 @@ def active_information(
Active information for the given channel
"""

assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"
assert his > 0, "history length should be a finite positive value"
bin_spike = binned_spiketrain(spiketrains, channel, t_start, t_end, bin_size)
active_information = pyinform.activeinfo.active_info(bin_spike, his)
return active_information
@@ -209,13 +267,60 @@ def mutual_information(
Mutual information for the given pair of electrodes
"""

assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"
bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
mutual_information = pyinform.mutualinfo.mutual_info(bin_spike_x, bin_spike_y)
return mutual_information
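
# Toy sanity check: a binned train shares all of its information with
# itself, so I(X;X) equals H(X), here 1 bit since half the bins spike.
import numpy as np
import pyinform

toy_x = np.array([0, 1, 1, 0, 1, 0, 0, 1])
print(pyinform.mutualinfo.mutual_info(toy_x, toy_x))  # 1.0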


def joint_entropy(
spiketrains: SpikestampsType,
channelx: float,
channely: float,
t_start: float,
t_end: float,
bin_size: float,
):
"""
Estimates the joint entropy for the pair of electrode recordings (X & Y) using the binned spiketrains.
Parameters
----------
spiketrains : SpikestampsType
Single spike-stamps
channelx : float
electrode/channel X
channely : float
electrode/channel Y
t_start : float
Binning start time
t_end : float
Binning end time
bin_size : float
bin size in seconds
Returns
-------
joint_entropy: float
joint entropy for the given pair of electrodes
"""
assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"

bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
# Encode each (x, y) bin pair as one of four states and estimate the joint
# distribution by counting; a logical AND of the two trains would only
# capture the (1, 1) state and yields no valid probabilities.
states = 2 * np.asarray(bin_spike_x, dtype=int) + np.asarray(bin_spike_y, dtype=int)
joint_prob = np.bincount(states, minlength=4) / np.size(states)
# Zero-probability states are skipped, following the 0 * log2(0) = 0 convention.
joint_prob = joint_prob[joint_prob > 0]
joint_entropy = -np.sum(joint_prob * np.log2(joint_prob))
return joint_entropy
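
# Toy check of the four-state estimate above: for these trains the (x, y)
# counts are 00:4, 01:1, 10:1, 11:2 over 8 bins, so H(X,Y) = 1.75 bits.
import numpy as np

toy_x = np.array([1, 0, 0, 1, 0, 0, 0, 1])
toy_y = np.array([1, 1, 0, 1, 0, 0, 0, 0])
toy_p = np.bincount(2 * toy_x + toy_y, minlength=4) / toy_x.size
toy_p = toy_p[toy_p > 0]
print(-np.sum(toy_p * np.log2(toy_p)))  # 1.75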


def relative_entropy(
spiketrains: SpikestampsType,
channelx: float,
@@ -248,6 +353,8 @@ def relative_entropy(
Relative_entropy for the given pair of electrodes
"""
assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"

bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
@@ -289,6 +396,8 @@ def conditional_entropy(
conditional entropy for the given pair of electrodes
"""
assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"

bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
@@ -298,6 +407,51 @@
return conditional_entropy


def cross_entropy(
spiketrains: SpikestampsType,
channelx: float,
channely: float,
t_start: float,
t_end: float,
bin_size: float,
):
"""
Estimates the cross entropy for the pair of electrode recordings (X & Y) using the binned spiketrains.
Parameters
----------
spiketrains : SpikestampsType
Single spike-stamps
channelx : float
electrode/channel X
channely : float
electrode/channel Y
t_start : float
Binning start time
t_end : float
Binning end time
bin_size : float
bin size in seconds
Returns
-------
cross_entropy: float
cross entropy for the given pair of electrodes
"""
assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"

bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
# Two-state (no-spike, spike) distributions for each channel, matching the
# Bernoulli model used by shannon_entropy above; summing the per-bin
# probability arrays directly would scale with the number of bins.
prob_spike_x = np.sum(bin_spike_x) / np.size(bin_spike_x)
prob_spike_y = np.sum(bin_spike_y) / np.size(bin_spike_y)
dist_x = np.array([1 - prob_spike_x, prob_spike_x])
dist_y = np.array([1 - prob_spike_y, prob_spike_y])
assert np.all(dist_y > 0), "cross entropy is undefined when Y assigns zero probability to a state"
cross_entropy = -np.sum(dist_x * np.log2(dist_y))
return cross_entropy
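
# Worked check of the two-state cross entropy above, with toy spike
# probabilities p = 0.375 for X and q = 0.25 for Y; note that q = p would
# recover the Shannon entropy of X.
import numpy as np

p, q = 0.375, 0.25
print(round(-(p * np.log2(q) + (1 - p) * np.log2(1 - q)), 3))  # 1.009 bits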


def transfer_entropy(
spiketrains: SpikestampsType,
channelx: float,
@@ -333,10 +487,95 @@
Transfer_entropy for the given pair of electrodes
"""
assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"
assert his > 0, "history length should be a finite positive value"

bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
transfer_entropy = pyinform.transferentropy.transfer_entropy(
bin_spike_x, bin_spike_y, k=his
)
return transfer_entropy
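
# A toy pyinform sanity check: toy_y copies toy_x with a one-bin lag, so the
# transfer entropy from x to y should exceed the reverse direction.
import numpy as np
import pyinform

toy_x = np.array([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])
toy_y = np.roll(toy_x, 1)
print(pyinform.transferentropy.transfer_entropy(toy_x, toy_y, k=1))
print(pyinform.transferentropy.transfer_entropy(toy_y, toy_x, k=1))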


def pid(
spiketrains: SpikestampsType,
channelx: float,
channely: float,
channelz: float,
t_start: float,
t_end: float,
bin_size: float,
):
"""
Decomposes the information provided by channels x and y about channel z into redundancy, unique information, and synergy.
Parameters
----------
spiketrains : SpikestampsType
Single spike-stamps
channelx : float
electrode/channel X
channely : float
electrode/channel Y
channelz : float
electrode/channel Z
t_start : float
Binning start time
t_end : float
Binning end time
bin_size : float
bin size in seconds
Returns
-------
redundancy: float
redundant information provided by both x and y about z
unique_information_x: float
information uniquely provided by x about z
unique_information_y: float
information uniquely provided by y about z
synergy: float
synergetic information provided by both x and y about z
total_information: float
total information provided by both x and y about z
"""
assert t_start < t_end, "start time cannot be equal to or greater than end time"
assert bin_size > 0, "bin_size should be a finite positive value"

bin_spike_x = binned_spiketrain(spiketrains, channelx, t_start, t_end, bin_size)
bin_spike_y = binned_spiketrain(spiketrains, channely, t_start, t_end, bin_size)
bin_spike_z = binned_spiketrain(spiketrains, channelz, t_start, t_end, bin_size)
I_x_z = pyinform.mutualinfo.mutual_info(bin_spike_x, bin_spike_z)
I_y_z = pyinform.mutualinfo.mutual_info(bin_spike_y, bin_spike_z)
# Redundancy taken as the smaller of the two single-source informations
# (the minimum-mutual-information approach).
redundancy = min(I_x_z, I_y_z)
unique_information_x = I_x_z - redundancy
unique_information_y = I_y_z - redundancy
# Total information I(X,Y; Z): encode each joint (x, y) bin pair as a single
# four-state symbol so the pair acts as one source.
bin_spike_xy = 2 * np.asarray(bin_spike_x, dtype=int) + np.asarray(
bin_spike_y, dtype=int
)
total_information = pyinform.mutualinfo.mutual_info(bin_spike_xy, bin_spike_z)
synergy = (
total_information - redundancy - unique_information_x - unique_information_y
)
return (
redundancy,
unique_information_x,
unique_information_y,
synergy,
total_information,
)
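
# Hypothetical usage sketch (the spiketrain object, channel numbers, and
# timing values below are placeholders): the four atoms partition the total
# information by construction.
red, uniq_x, uniq_y, syn, total = pid(
spiketrains, channelx=0, channely=1, channelz=2,
t_start=0.0, t_end=60.0, bin_size=0.002,
)
assert abs((red + uniq_x + uniq_y + syn) - total) < 1e-9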
