[WIP] Feature abnormality detection #98

Merged
146 commits
141ed78
add SpikeCutout class
jihugo Jul 1, 2022
9ab866c
fix Spikecutout initializer
jihugo Jul 1, 2022
5337f61
Merge branch 'GazzolaLab:main' into feature_abnormality_detection
jihugo Jul 1, 2022
b5ac936
clean up redundant stuff
jihugo Jul 1, 2022
4cd97dc
:construction: Abnormality Detector
jihugo Jul 1, 2022
ce2d95c
:construction: add ChannelSpikeCutout class
jihugo Jul 1, 2022
2cae6ce
:construction: edit _get_cutouts()
jihugo Jul 1, 2022
807ac2b
Merge branch 'feature_abnormality_detection' of https://github.com/ji…
jihugo Jul 1, 2022
808fe56
:construction: fix type error
jihugo Jul 1, 2022
d07c23c
:construction: add documentation
jihugo Jul 1, 2022
063b384
:construction: update inline doc
jihugo Jul 1, 2022
6c51e0b
:construction: add categorize_spontaneous() and inline doc
jihugo Jul 1, 2022
f992937
:construction: add categorize()
jihugo Jul 1, 2022
5cb97d6
:construction: add get_labeled_cutouts()
jihugo Jul 1, 2022
b65811e
:construction: add size into return value for get_labeled_cutouts()
jihugo Jul 1, 2022
f205928
:construction: add train_model()
jihugo Jul 1, 2022
39f4d07
:construction: correct typing error
jihugo Jul 1, 2022
b8de9a4
:construction: add get_cutouts_by_component() and inline doc
jihugo Jul 2, 2022
695e238
:construction: add dunder functions
jihugo Jul 2, 2022
b7bf194
add plot_waveforms_with_SpikeCutout(...)
jihugo Jul 2, 2022
c554c6b
:construction: fix typing errors
jihugo Jul 2, 2022
0e64a9c
:construction: fix typing errors
jihugo Jul 2, 2022
c7c3799
include new function in __all__
jihugo Jul 2, 2022
88dfc68
:construction: get rid of misused dunder function
jihugo Jul 2, 2022
216d7df
fix n_spikes in new function
jihugo Jul 2, 2022
c2547cc
:construction: add num_channels variable
jihugo Jul 2, 2022
3e13c43
:construction: include empty channels in _get_cutouts(...)
jihugo Jul 4, 2022
1b93b2a
:construction: make skipped channels a class variable
jihugo Jul 4, 2022
c8ad5bd
:construction: fix categorization bug and typos
jihugo Jul 4, 2022
8b5e3ef
:construction: fix error with getting labeled cutouts
jihugo Jul 6, 2022
a62ea2c
:construction: change category numbers to have -1 to reflect 'uncateg…
jihugo Jul 6, 2022
8dd7e0c
:construction: fix some errors in train_model(...)
jihugo Jul 6, 2022
f390325
:construction: minor fixes
jihugo Jul 6, 2022
20d5326
:construction: change categorized bool from channel level to spike le…
jihugo Jul 6, 2022
516ac47
:construction: add get_only_neuronal_spikes(...)
jihugo Jul 6, 2022
cee19fb
:construction: fix typing errors
jihugo Jul 6, 2022
de09446
:construction: fix some errors with get_only_neuronal_spikes(...)
jihugo Jul 7, 2022
9a002a9
:construction: add time to SpikeCutout object as variable
jihugo Jul 7, 2022
4473917
:construction: add MockSpikeCutout object
jihugo Jul 7, 2022
bbb0fc7
:construction: test_get_cutouts_by_components
jihugo Jul 8, 2022
fb3c232
:construction: fix: make categorize() change self.categorized to True
jihugo Jul 8, 2022
7862a36
:construction: add test_categorize() and test_labeled_cutouts()
jihugo Jul 8, 2022
57eb8f2
add test_len()
jihugo Jul 8, 2022
6ca4559
:construction: add MockAbnormalityDetector and test_categorize_sponta…
jihugo Jul 8, 2022
296a8c7
:construction: add inline doc and start get_only_neuronal_components
jihugo Jul 10, 2022
147ffcb
:construction: add get_only_neuronal_components
jihugo Jul 10, 2022
0199080
:construction: fix typing errors
jihugo Jul 10, 2022
2f8d7c3
:construction: add test_train_model(), found error with train_model()
jihugo Jul 10, 2022
f7d5f39
:construction: add check for model is trained for getting new spiketr…
jihugo Jul 10, 2022
25936ef
add KerasModelType
jihugo Jul 14, 2022
1a4dc12
:construction: use SpikeFeatureExtractionProtocol instead of PCADecom…
jihugo Jul 14, 2022
bdc1f8c
add kerasModelType
jihugo Jul 14, 2022
c51f8b0
:construction: use spike_feature_extractor instead of specific PCADec…
jihugo Jul 14, 2022
9a09b7c
:construction: separate evaluate_model(..) from train_model(..)
jihugo Jul 14, 2022
833ff10
get rid of KerasModelType
jihugo Jul 15, 2022
6fbef2f
:construction: temporarily remove test_train_model()
jihugo Jul 15, 2022
0552199
:construction: add SpikeClassificationModelProtocol
jihugo Jul 15, 2022
a739312
:construction: add compile(..)
jihugo Jul 15, 2022
aa48001
add project(..) to SpikeFeatureExtractionProtocol
jihugo Jul 15, 2022
483bb4b
:construction: fix some errors
jihugo Jul 15, 2022
ca244c0
:construction: update code after changes in AbnormalityDetector
jihugo Jul 15, 2022
204e7eb
:construction: add test_create_default_model()
jihugo Jul 15, 2022
7abe6d4
:construction: add test_create_default_model()
jihugo Jul 15, 2022
cd68af5
:contruction: add test_evaluate_model()
jihugo Jul 18, 2022
8186a88
:construction: fix code error
jihugo Jul 18, 2022
e2dae65
:construction: fix some errors
jihugo Jul 18, 2022
94f04fd
fix issues with categorization
jihugo Jul 19, 2022
d010456
add neuronal_spike_classification
jihugo Jul 21, 2022
3621f6b
add inline doc
jihugo Jul 21, 2022
b75923a
add import neuronal_spike_classification
jihugo Jul 21, 2022
813b5eb
name change from 'events' to 'classification'
jihugo Jul 21, 2022
261c8c4
rewrite __init__ and _get_all_cutouts
jihugo Jul 21, 2022
601d918
change categorize_spontaneous to use np.ndarray
jihugo Jul 21, 2022
7fab1b5
make get_labeled_cutouts return labels and cutouts in np.ndarray inst…
jihugo Jul 21, 2022
220666d
minor name change for get_labeled_cutouts()['cutouts']
jihugo Jul 21, 2022
6174131
minor name change for get_labeled_cutouts()['cutouts']
jihugo Jul 21, 2022
be0344b
rewrite train_model into init_classifier and train_classifier_model t…
jihugo Jul 21, 2022
8607fab
name change
jihugo Jul 21, 2022
efbb11f
get rid of create default model function
jihugo Jul 21, 2022
64398ed
get rid of old __init__
jihugo Jul 21, 2022
a12a05c
organize directory
jihugo Jul 21, 2022
c7f69bc
get rid of redundant test cases
jihugo Jul 21, 2022
85885e5
add _check_categorized()
jihugo Jul 21, 2022
e2962de
minor fixes
jihugo Jul 21, 2022
2b86de2
change import directory after organization
jihugo Jul 22, 2022
994170e
update key name for get_labeled_cutouts return value
jihugo Jul 22, 2022
475e97f
add default_model() test case and fix some errors
jihugo Jul 22, 2022
12ab700
fix typing errors
jihugo Jul 22, 2022
416758b
fix typing errors
jihugo Jul 22, 2022
5aa081f
add AdvancedMockData
jihugo Jul 22, 2022
d720b17
fix errors with **kwargs
jihugo Jul 24, 2022
ccf9b52
reorganize and add notebook, will run later since it may take some hours
jihugo Jul 26, 2022
f8ae6bd
add train data file
jihugo Jul 26, 2022
cf61639
reveal best combination
jihugo Jul 26, 2022
2ec1cc4
change default compile parameters after result from comparison
jihugo Jul 26, 2022
39a8a85
fix typo
jihugo Jul 26, 2022
dee7dc1
Add EarlyStopping callback to fit function and update best combination
jihugo Jul 26, 2022
39b6b63
add default_init_and_train_model function
jihugo Jul 27, 2022
b595eb7
Add EarlyStopping callback and change loss to BinaryCrossentropy
jihugo Jul 27, 2022
1441c73
:construction: add test_predict_categories()
jihugo Jul 27, 2022
7451f57
correct output layer
jihugo Jul 29, 2022
fec8700
match new function argument format for create_default_model
jihugo Jul 29, 2022
61a4f5c
change prediction function and test to use sigmoid output
jihugo Jul 29, 2022
ee6a15d
add test_confusion_matrix and fix error in get_confusion_matrix
jihugo Jul 29, 2022
2781a06
Merge branch 'main' of https://github.com/GazzolaLab/MiV-OS into feat…
jihugo Jul 29, 2022
5dc9c47
formatting
jihugo Jul 29, 2022
e65021d
fix AdvancedMockData
jihugo Jul 29, 2022
78dbeea
fix AdvancedMockData error with signal shape
jihugo Jul 29, 2022
6280060
add test_get_all_cutouts and edit typing
jihugo Jul 29, 2022
9261907
add test_categorize_spontaneous and fix error with categorize_spontan…
jihugo Jul 30, 2022
617d01b
fix error with train_model()
jihugo Aug 3, 2022
158f01d
fix typing errors
jihugo Aug 3, 2022
3bcd6a5
add evaluate_model()
jihugo Aug 3, 2022
22118f1
check classifier
jihugo Aug 3, 2022
d72426c
add compile functions
jihugo Aug 3, 2022
cbaa912
:construction: add test case and fix errors with evaluate_model and o…
jihugo Aug 3, 2022
8592247
Delete redundant imports and fix error with _get_all_cutouts()
jihugo Aug 6, 2022
f457132
make get_ctuouts_by_component() return raw spikes instead of SpikeCut…
jihugo Aug 6, 2022
9c97796
add get_spontaneous_cutouts_by_component()
jihugo Aug 6, 2022
49170a2
update test cases
jihugo Aug 11, 2022
cd0f620
rename and restructure for new class of detector
jihugo Aug 11, 2022
ef461ff
add DetectorWithTrainData
jihugo Aug 11, 2022
ebcdd2d
:construction: add new functions
jihugo Aug 11, 2022
5d8bbba
add default_compile_and_train()
jihugo Aug 11, 2022
6d72a54
add keep_only_neuronal_spikes()
jihugo Aug 11, 2022
e74518e
fix typing errors
jihugo Aug 11, 2022
4c4942a
fix conflicts
jihugo Aug 11, 2022
60802f6
add classifer demo
jihugo Aug 11, 2022
4a16ce4
merge from head
jihugo Aug 11, 2022
0fd9272
resolve conflicts
jihugo Aug 11, 2022
6ff0b68
add keep_only_neuronal_spikes_from_data
jihugo Aug 17, 2022
c0888f7
change return type for keep_only_neuronal_spikes_from_data
jihugo Aug 17, 2022
470b279
fix some typing errors
jihugo Aug 17, 2022
f6fde5e
:construction: add train_model()
jihugo Aug 21, 2022
84a28a7
:construction: fix keep_only_neuronal_spikes_from_data()
jihugo Aug 22, 2022
92f759a
:construction: add pre and post values for extract_waveforms in keep_…
jihugo Aug 22, 2022
d43e4ba
:construction: change keep_only_neuronal_spikes_from_data to use the …
jihugo Aug 22, 2022
b35d2cf
fix some errors
jihugo Aug 22, 2022
ee21055
add auto_masking_with_classification_demo notebook
jihugo Aug 22, 2022
74d2007
Merge branch 'GazzolaLab:main' into feature_abnormality_detection
jihugo Aug 22, 2022
616ca3f
fix merge conflict
jihugo Aug 22, 2022
1b3bff9
Merge branch 'feature_abnormality_detection' of https://github.com/ji…
jihugo Aug 22, 2022
c0d977f
added some sentence
jihugo Aug 22, 2022
9f65111
Merge branch 'update-0.2.3' into feature_abnormality_detection
skim0119 Dec 20, 2022
cd7dda6
feat: add window view
skim0119 Feb 9, 2023
68bada3
Merge remote-tracking branch 'public/update-0.2.4' into feature_abnor…
skim0119 Feb 9, 2023
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### This notebook seeks the best combination of Keras model optimizer and loss used in the fit function.\n",
"__Accuracy is the goal__"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"from tqdm import tqdm\n",
"\n",
"datapath = \"./train_data_00.npz\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"with np.load(datapath) as file:\n",
"    labels = file['label']\n",
"    spikes = file['spike']\n",
"\n",
"from sklearn.utils import shuffle\n",
"labels, spikes = shuffle(labels, spikes)\n",
"\n",
"sample_percentage = 0.4\n",
"cut = int(sample_percentage*len(labels))\n",
"labels = labels[:cut]\n",
"spikes = spikes[:cut]\n",
"\n",
"split = int(len(labels) * 0.8)\n",
"train_labels = labels[:split]\n",
"train_spikes = spikes[:split]\n",
"test_labels = labels[split:]\n",
"test_spikes = spikes[split:]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"spike_length = np.shape(train_spikes)[1]\n",
"hidden_layer_size = spike_length\n",
"\n",
"layers = [\n",
" tf.keras.layers.Dense(spike_length),\n",
" tf.keras.layers.Dense(hidden_layer_size),\n",
" tf.keras.layers.Dense(len(np.unique(labels)))\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"optimizers = np.array([\n",
" \"Adadelta\",\n",
" \"Adagrad\",\n",
" \"Adam\",\n",
" \"Adamax\",\n",
" \"Ftrl\",\n",
" \"Nadam\",\n",
" \"RMSprop\",\n",
" \"SGD\"\n",
"])\n",
"\n",
"losses = np.array([\n",
" \"BinaryCrossentropy\",\n",
" \"BinaryFocalCrossentropy\",\n",
" # \"CategoricalCrossentropy\",\n",
" \"CategoricalHinge\",\n",
" # \"CosineSimilarity\",\n",
" \"Hinge\",\n",
" \"Huber\",\n",
" \"KLDivergence\",\n",
" \"LogCosh\",\n",
" \"MeanAbsoluteError\",\n",
" \"MeanAbsolutePercentageError\",\n",
" \"MeanSquaredError\",\n",
" \"MeanSquaredLogarithmicError\",\n",
" \"Poisson\",\n",
" \"SparseCategoricalCrossentropy\",\n",
" \"SquaredHinge\"\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"8it [21:53, 164.22s/it]\n"
]
}
],
"source": [
"accuracies = np.zeros((len(optimizers), len(losses)))\n",
"for i, optimizer in tqdm(enumerate(optimizers)):\n",
"    for j, loss in enumerate(losses):\n",
"        model = tf.keras.Sequential(layers)\n",
"        model.compile(\n",
"            optimizer=optimizer,\n",
"            loss=loss,\n",
"            metrics=[\"accuracy\"]\n",
"        )\n",
"        callback = tf.keras.callbacks.EarlyStopping(monitor=\"accuracy\", patience=1)\n",
"        model.fit(train_spikes, train_labels, epochs=6, callbacks=[callback], verbose=0)\n",
"        model_loss, model_acc = model.evaluate(test_spikes, test_labels, verbose=0)\n",
"        accuracies[i][j] = model_acc"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Most accurate combination = Adam + BinaryCrossentropy\n",
"with accuracy = 0.8470870852470398\n"
]
}
],
"source": [
"best_idx = np.argmax(accuracies)\n",
"best_opt_idx = best_idx // len(losses)\n",
"best_loss_idx = best_idx % len(losses)\n",
"print(\"Most accurate combination = \", optimizers[best_opt_idx], \"+\", losses[best_loss_idx])\n",
"print(\"with accuracy = \", accuracies[best_opt_idx][best_loss_idx])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "814c8425decae4f86b0a2793668b5d1e72243fbb280f353401e2c57732588a25"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
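The last cell of the notebook above recovers the (optimizer, loss) pair from a flat `argmax` index by dividing and taking the remainder. A minimal sketch of the same lookup with `np.unravel_index` follows. The optimizer and loss lists are shortened for illustration, and the accuracy grid is filled with random placeholder numbers, not measured results.

```python
import numpy as np

# Illustrative subsets of the lists used in the notebook.
optimizers = np.array(["Adam", "SGD", "RMSprop"])
losses = np.array(["BinaryCrossentropy", "Hinge"])

# Placeholder grid: random numbers standing in for measured accuracies.
rng = np.random.default_rng(0)
accuracies = rng.random((len(optimizers), len(losses)))

# argmax on a 2-D array returns an index into the flattened (row-major) array;
# unravel_index converts that flat index back into (row, column).
flat_idx = np.argmax(accuracies)
opt_idx, loss_idx = np.unravel_index(flat_idx, accuracies.shape)

print("best:", optimizers[opt_idx], "+", losses[loss_idx])
```

One caveat on the search loop itself: it passes the same `layers` list to every `tf.keras.Sequential`, so the `Dense` layer objects and their weights are shared across combinations, and later combinations may start from weights already trained by earlier ones. Building fresh layers inside the loop would isolate each run.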
170 changes: 170 additions & 0 deletions docs/discussion/spike_classification/Classifier_demo.ipynb
@@ -0,0 +1,170 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datapath: str = \"./train_data_00.npz\"\n",
"import os\n",
"os.path.exists(datapath)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.utils import shuffle\n",
"\n",
"with np.load(datapath) as file:\n",
" labels, spikes = shuffle(file['label'], file['spike'])\n",
" sample_percentage = 1\n",
" cut = int(sample_percentage*len(labels))\n",
" labels = labels[:cut]\n",
" spikes = spikes[:cut]\n",
"\n",
"length = len(spikes[0])\n",
"split = int(len(labels) * 0.8)\n",
"train_labels = labels[:split]\n",
"train_spikes = spikes[:split]\n",
"test_labels = labels[split:]\n",
"test_spikes = spikes[split:]\n",
"del labels\n",
"del spikes"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from miv.signal.classification.neuronal_spike_classification import NeuronalSpikeClassifier\n",
"classifier = NeuronalSpikeClassifier()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"classifier.create_default_tf_keras_model(length)\n",
"classifier.default_compile_model()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"5075/5075 [==============================] - 7s 1ms/step - loss: 1.2742 - accuracy: 0.7948\n",
"Epoch 2/5\n",
"5075/5075 [==============================] - 7s 1ms/step - loss: 0.5682 - accuracy: 0.8112\n",
"Epoch 3/5\n",
"5075/5075 [==============================] - 7s 1ms/step - loss: 0.5358 - accuracy: 0.8144\n",
"Epoch 4/5\n",
"5075/5075 [==============================] - 7s 1ms/step - loss: 0.5386 - accuracy: 0.8144\n",
"Epoch 5/5\n",
"5075/5075 [==============================] - 7s 1ms/step - loss: 0.5200 - accuracy: 0.8151\n"
]
}
],
"source": [
"classifier.default_train_model(train_spikes, train_labels)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1269/1269 [==============================] - 1s 1ms/step\n"
]
}
],
"source": [
"predictions = classifier.predict_categories_sigmoid(test_spikes)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10144, 3851],\n",
" [ 2565, 24036]], dtype=int64)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier.get_confusion_matrix(test_spikes, test_labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "814c8425decae4f86b0a2793668b5d1e72243fbb280f353401e2c57732588a25"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
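The confusion matrix printed in the demo notebook can be reduced to a few summary numbers. The sketch below copies the four counts from that output. It assumes that rows are true labels and columns are predicted labels; the diff does not show which convention `get_confusion_matrix` uses, so the precision and recall values swap if the matrix is transposed. Accuracy does not depend on that assumption.

```python
import numpy as np

# Counts copied from the notebook output above.
cm = np.array([[10144, 3851],
               [2565, 24036]])

# Assumption: cm[i, j] counts samples with true label i predicted as label j.
total = cm.sum()
accuracy = np.trace(cm) / total          # correct predictions lie on the diagonal
precision_1 = cm[1, 1] / cm[:, 1].sum()  # of everything predicted as 1, fraction truly 1
recall_1 = cm[1, 1] / cm[1, :].sum()     # of everything truly 1, fraction predicted as 1

print(f"samples={total} accuracy={accuracy:.4f} "
      f"precision(1)={precision_1:.4f} recall(1)={recall_1:.4f}")
```

The total of 40596 samples is consistent with the 1269 prediction steps shown in the notebook if the default Keras batch size of 32 was used.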
Binary file not shown.
9 changes: 9 additions & 0 deletions miv/core/spikestamps.py
@@ -82,3 +82,12 @@ def get_last_spikestamp(self):
    def get_first_spikestamp(self):
        """Return timestamps of the first spike in this spikestamps"""
        return min([data[0] for data in self.data if len(data) > 0])

    def get_view(self, tstart: float, tend: float):
        """Return a new Spikestamps containing only the spikestamps between tstart and tend (inclusive)."""
        return Spikestamps(
            [
                np.array(sorted(list(filter(lambda x: tstart <= x <= tend, arr))))
                for arr in self.data
            ]
        )
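A minimal sketch of the same windowing with a NumPy boolean mask instead of `filter` and `sorted`. It assumes each channel is a 1-D array of spike times. `window_view` is a hypothetical standalone helper, not part of MiV-OS, and it returns a list of plain arrays rather than a `Spikestamps` object.

```python
import numpy as np


def window_view(channels, tstart, tend):
    """Keep spike times in the closed interval [tstart, tend] for each channel."""
    views = []
    for arr in channels:
        arr = np.asarray(arr, dtype=float)
        mask = (arr >= tstart) & (arr <= tend)  # inclusive on both ends, as in get_view
        views.append(np.sort(arr[mask]))        # sorted, to match get_view's output order
    return views


# Three example channels: ordered, empty, and unordered.
channels = [np.array([0.1, 0.5, 0.9, 1.4]), np.array([]), np.array([2.0, 0.7])]
windowed = window_view(channels, 0.5, 1.4)
```

The mask avoids a Python-level call per spike, which may matter for long recordings. Empty channels pass through as empty arrays.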
3 changes: 3 additions & 0 deletions miv/signal/classification/__init__.py
@@ -0,0 +1,3 @@
from miv.signal.classification.abnormality_detection import *
from miv.signal.classification.neuronal_spike_classification import *
from miv.signal.classification.protocol import *