fix agender_agender feature #146

Merged · 10 commits · Jul 30, 2024
5 changes: 3 additions & 2 deletions README.md
@@ -160,9 +160,10 @@ All of them take *--config <my_config.ini>* as an argument.
* **nkululeko.nkululeko**: do machine learning experiments combining features and learners
* **nkululeko.ensemble**: [combine several nkululeko experiments](http://blog.syntheticspeech.de/2024/06/25/nkululeko-ensemble-classifiers-with-late-fusion/) and report on late fusion results
* *configurations*: which experiments to combine
* *--method* (optional): majority_voting, mean, max, sum, max_class, uncertainty_threshold, uncertainty_weighted, confidence_weighted
* *--method* (optional): majority_voting, mean (default), max, sum, uncertainty, uncertainty_weighted, confidence_weighted, performance_weighted
* *--threshold*: uncertainty threshold (1.0 means no threshold)
* *--outfile* (optional): name of CSV file for output
* *--weights*: weights for the performance_weighted method (e.g., from previous UAR or ACC results); see the usage sketch after this hunk
* *--outfile* (optional): name of CSV file for output (default: ensemble_result.csv)
* *--no_labels* (optional): indicate that no ground truth is given
* **nkululeko.multidb**: do [multiple experiments](http://blog.syntheticspeech.de/2024/01/02/nkululeko-compare-several-databases/), comparing several databases cross-corpus and within themselves
* **nkululeko.demo**: [demo the current best model](http://blog.syntheticspeech.de/2022/01/24/nkululeko-try-out-demo-a-trained-model/) on the command line
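A minimal usage sketch for the ensemble options listed above. The two experiment configuration names are hypothetical, and the flag spellings follow this README hunk rather than a verified CLI reference:

```python
# Hedged sketch: late fusion over two previously run nkululeko experiments.
# The .ini names are placeholders; the flags mirror the README list above.
import subprocess

subprocess.run(
    [
        "python3", "-m", "nkululeko.ensemble",
        "exp_emodb_os_svm.ini", "exp_emodb_wav2vec_xgb.ini",  # configurations to combine
        "--method", "uncertainty_weighted",
        "--threshold", "0.9",                # 1.0 would disable the threshold
        "--outfile", "ensemble_result.csv",  # default output name per the README
    ],
    check=True,
)
```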
14 changes: 10 additions & 4 deletions data/banglaser/README.md
@@ -1,8 +1,8 @@
# Nkululeko pre-processing for Bangla SER dataset

Download link: <https://prod-dcd-datasets-cache-zipfiles.s3.eu-west-1.amazonaws.com/t9h6p943xy-5.zip>
Filename convention:
Download link: <https://prod-dcd-datasets-cache-zipfiles.s3.eu-west-1.amazonaws.com/t9h6p943xy-5.zip>

Filename convention:
AA-BB-CC-DD-EE-FF-GG.wav

```bash
@@ -12,8 +12,14 @@ python3 process_database.py
cd ../..
python3 -m nkululeko.resample --config data/banglaser/exp.ini
python3 -m nkululeko.nkululeko --config data/banglaser/exp.ini

...
# sample outputs
DEBUG: reporter: Best score at epoch: 0, UAR: .681, (+-.611/.747), ACC: .672
DEBUG: reporter: labels: ['angry', 'neutral', 'sad', 'happy']
DEBUG: reporter: result per class (F1 score): [0.681, 0.494, 0.886, 0.674] from epoch: 0
DEBUG: experiment: Done, used 180.702 seconds
DONE
```

Reference:
[1]
[1] Das, R. K., Islam, N., Ahmed, M. R., Islam, S., Shatabda, S., & Islam, A. K. M. M. (2022). BanglaSER: A speech emotion recognition dataset for the Bangla language. Data in Brief, 42, 108091. https://doi.org/10.1016/j.dib.2022.108091
7 changes: 6 additions & 1 deletion data/gerparas/README.md
@@ -6,9 +6,14 @@ Test speakers: gauland (102 samples) and weidel (122 samples).

Original audio source: [Deutscher Bundestag](https://www.bundestag.de/) in the year 2020.

All segments were then random spliced as described in the paper. The dataset, which is restricted, can be obtaine from [1]. See [2] for details.
All segments were then randomly spliced as described in the paper. The dataset, which is restricted, can be obtained from [1]. See [2] for details.

```bash
python3 process_database.py
cd ../..
python3 -m nkululeko.nkululeko --config data/gerparas/exp.ini
# output sample

```

Reference:
4 changes: 2 additions & 2 deletions ini_file.md
@@ -228,7 +228,7 @@
* no_reuse = False
* **store_format**: how to store the features: possible values [pkl | csv]
* store_format = pkl
* **scale**: scale the features
* **scale**: scale the features (important for gmm); see the sketch after this hunk
* scale=standard
* possible values:
* **standard**: z-transformation (mean of 0 and std of 1) based on the training set
@@ -264,7 +264,7 @@
* **max_duration**: Max. duration of samples/segments for the transformer in seconds; frames are pooled.
* max_duration = 8.0
* **gmm**: Gaussian mixture classifier
* GMM_components = 4
* GMM_components = 4 (currently must be the same as number of labels)
* GMM_covariance_type = [full | tied | diag | spherical](https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html)
* **knn**: k nearest neighbor classifier
* K_val = 5
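A small scikit-learn sketch of what the scale = standard bullet above means: the z-transformation statistics come from the training set only and are then reused on the test set. The toy feature matrices are made up:

```python
# Hedged sketch of "standard" scaling: fit mean/std on the training features,
# then apply the same transform to the test features.
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
feats_train = rng.normal(loc=5.0, scale=2.0, size=(100, 6))  # toy training features
feats_test = rng.normal(loc=5.0, scale=2.0, size=(20, 6))    # toy test features

scaler = StandardScaler().fit(feats_train)     # statistics estimated on train only
feats_train_z = scaler.transform(feats_train)  # ~ mean 0, std 1 per feature
feats_test_z = scaler.transform(feats_test)    # same transform reused, not refitted
print(feats_train_z.mean(axis=0).round(2), feats_train_z.std(axis=0).round(2))
```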
21 changes: 21 additions & 0 deletions nkululeko/explore.py
@@ -1,3 +1,24 @@
"""
Explore the feature sets of a machine learning experiment.

This script is the entry point for the 'explore' module of the nkululeko framework.
It handles loading the experiment configuration, setting up the experiment, and
running various feature exploration techniques based on the configuration.

The script supports the following configuration options:
- `no_warnings`: If set to `True`, it will ignore all warnings during the exploration.
- `feature_distributions`: If set to `True`, it will generate plots of the feature distributions.
- `tsne`: If set to `True`, it will generate a t-SNE plot of the feature space.
- `scatter`: If set to `True`, it will generate a scatter plot of the feature space.
- `spotlight`: If set to `True`, it will generate a 'spotlight' plot of the feature space.
- `shap`: If set to `True`, it will generate SHAP feature importance plots.
- `model`: The type of model to use for the feature exploration (e.g. 'SVM').
- `plot_tree`: If set to `True`, it will generate a decision tree plot.

The script can be run from the command line with the `--config` argument to specify
the configuration file to use. If no configuration file is provided, it will look
for an `exp.ini` file in the same directory as the script.
"""
# explore.py
# explore the feature sets

9 changes: 5 additions & 4 deletions nkululeko/feat_extract/feats_agender_agender.py
@@ -7,18 +7,19 @@
import audonnx
import numpy as np
import audinterface
import torch


class AgenderAgenderSet(Featureset):
class Agender_agenderSet(Featureset):
"""
Age and gender predictions from the wav2vec2-based model finetuned on agender, described in the paper
"Speech-based Age and Gender Prediction with Transformers"
https://arxiv.org/abs/2306.16962
"""

def __init__(self, name, data_df):
super().__init__(name, data_df)
def __init__(self, name, data_df, feats_type):
super().__init__(name, data_df, feats_type)
self.model_loaded = False
self.feats_type = feats_type

def _load_model(self):
model_url = "https://zenodo.org/record/7761387/files/w2v2-L-robust-6-age-gender.25c844af-1.1.1.zip"
2 changes: 1 addition & 1 deletion nkululeko/feat_extract/feats_spkrec.py
@@ -12,7 +12,7 @@
import torch
import torchaudio
from nkululeko.feat_extract.featureset import Featureset
from speechbrain.pretrained import EncoderClassifier
from speechbrain.inference import EncoderClassifier
from tqdm import tqdm

# from transformers import HubertModel, Wav2Vec2FeatureExtractor
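The change above tracks SpeechBrain's move of its pretrained interfaces from speechbrain.pretrained to speechbrain.inference. A standalone sketch of the imported class; the model source, cache directory, and audio file are assumptions, not part of this PR:

```python
# Hedged sketch of EncoderClassifier via the new speechbrain.inference path.
# Model source, savedir and the wav file are placeholders.
import torchaudio
from speechbrain.inference import EncoderClassifier

classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    savedir="pretrained_models/spkrec-ecapa-voxceleb",
)
signal, sample_rate = torchaudio.load("example.wav")  # hypothetical 16 kHz mono file
embeddings = classifier.encode_batch(signal)          # (batch, 1, embedding_dim)
print(embeddings.shape)
```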
23 changes: 20 additions & 3 deletions nkululeko/models/model_gmm.py
@@ -2,7 +2,7 @@

from sklearn import mixture
from nkululeko.models.model import Model

import pandas as pd

class GMM_model(Model):
"""An GMM model"""
@@ -12,9 +12,26 @@ class GMM_model(Model):
def __init__(self, df_train, df_test, feats_train, feats_test):
super().__init__(df_train, df_test, feats_train, feats_test)
self.name = "gmm"
n_components = int(self.util.config_val("MODEL", "GMM_components", "4"))
self.n_components = int(self.util.config_val("MODEL", "GMM_components", "4"))
covariance_type = self.util.config_val("MODEL", "GMM_covariance_type", "full")
self.clf = mixture.GaussianMixture(
n_components=n_components, covariance_type=covariance_type
n_components=self.n_components,
covariance_type=covariance_type,
random_state=42,
)
# set up the classifier

def get_predictions(self):
"""Use the predict_proba method of the GaussianMixture model to get
probabilities. Create a DataFrame with these probabilities and return
it along with the predictions."""
probs = self.clf.predict_proba(self.feats_test)
preds = self.clf.predict(self.feats_test)

# Convert predictions to a list
preds = preds.tolist()

# Create a DataFrame for probabilities
proba_df = pd.DataFrame(probs, index=self.feats_test.index, columns=range(self.n_components))

return preds, proba_df
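A standalone scikit-learn sketch (toy data) of the behaviour the new get_predictions() relies on: GaussianMixture.predict() returns component indices and predict_proba() returns one column per mixture component rather than per class label, which is why the ini_file.md note above expects GMM_components to match the number of labels:

```python
# Hedged sketch with made-up data: one probability column per mixture component,
# mirroring how get_predictions() builds its proba_df above.
import numpy as np
import pandas as pd
from sklearn import mixture

rng = np.random.default_rng(42)
feats_train = rng.normal(size=(90, 4))  # toy, unlabeled training features
feats_test = rng.normal(size=(10, 4))   # toy test features

gmm = mixture.GaussianMixture(n_components=3, covariance_type="diag", random_state=42)
gmm.fit(feats_train)                    # unsupervised fit: class labels are not used

preds = gmm.predict(feats_test).tolist()   # component index per test sample
probs = gmm.predict_proba(feats_test)      # shape (10, 3)
proba_df = pd.DataFrame(probs, columns=range(3))
print(preds, proba_df.shape)
```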
7 changes: 7 additions & 0 deletions nkululeko/multidb.py
@@ -1,3 +1,10 @@
"""
Demonstrates the usage of the ML-experiment framework for the nkululeko MULTIDB project.

The `main` function is the entry point of the script, which parses command-line arguments, reads a configuration file, and runs the nkululeko or aug_train functions based on the configuration.

The `plot_heatmap` function generates a heatmap plot of the results and saves it to a file, along with some summary statistics.
"""
# main.py
# Demonstration code to use the ML-experiment framework

21 changes: 11 additions & 10 deletions nkululeko/plots.py
@@ -4,14 +4,14 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
import seaborn as sns
from scipy import stats
from sklearn.manifold import TSNE

import nkululeko.glob_conf as glob_conf
import nkululeko.utils.stats as su
from nkululeko.reporting.defines import Header
from nkululeko.reporting.report_item import ReportItem
import nkululeko.utils.stats as su
from nkululeko.utils.util import Util


@@ -32,9 +32,9 @@ def plot_distributions_speaker(self, df):
# plot the distribution of samples per speaker
# one up because of the runs
fig_dir = self.util.get_path("fig_dir") + "../"
self.util.debug(f"plotting samples per speaker")
self.util.debug("plotting samples per speaker")
if "gender" in df_speakers:
filename = f"samples_value_counts"
filename = "samples_value_counts"
ax = (
df_speakers.groupby("samplenum")["gender"]
.value_counts()
@@ -46,7 +46,7 @@
rot=0,
)
)
ax.set_ylabel(f"number of speakers")
ax.set_ylabel("number of speakers")
ax.set_xlabel("number of samples")
self.save_plot(
ax,
@@ -58,7 +58,7 @@

# fig.clear()
else:
filename = f"samples_value_counts"
filename = "samples_value_counts"
ax = (
df_speakers["samplenum"]
.value_counts()
@@ -265,7 +265,8 @@ def plotcatcont(self, df, cat_col, cont_col, xlab, ylab):
"""Plot relation of categorical distribution with continuous."""
dist_type = self.util.config_val("EXPL", "dist_type", "hist")
cats, cat_str, es = su.get_effect_size(df, cat_col, cont_col)
if dist_type == "hist":
model_type = self.util.get_model_type()
if dist_type == "hist" and model_type != "tree":
ax = sns.histplot(df, x=cont_col, hue=cat_col, kde=True)
caption = f"{ylab} {df.shape[0]}. {cat_str} ({cats}):" f" {es}"
ax.set_title(caption)
@@ -489,7 +490,7 @@ def scatter_plot(self, feats, label_df, label, dimred_type):
glob_conf.report.add_item(
ReportItem(
Header.HEADER_EXPLORE,
f"Scatter plot",
"Scatter plot",
f"using {dimred_type}",
filename,
)
@@ -561,8 +562,8 @@ def plot_tree(self, model, features):
glob_conf.report.add_item(
ReportItem(
Header.HEADER_EXPLORE,
f"Tree plot",
f"for feature importance",
"Tree plot",
"for feature importance",
filename,
)
)
30 changes: 30 additions & 0 deletions tests/exp_polish_bayes.ini
@@ -0,0 +1,30 @@
[EXP]
root = ./tests/results/
name = exp_polish_os
[DATA]
databases = ['train', 'dev', 'test']
train = ./data/polish/polish_train.csv
train.type = csv
train.absolute_path = False
train.split_strategy = train
; train.audio_path = ./POLISH
dev = ./data/polish/polish_dev.csv
dev.type = csv
dev.absolute_path = False
dev.split_strategy = train
; dev.audio_path = ./POLISH
test = ./data/polish/polish_test.csv
test.type = csv
test.absolute_path = False
test.split_strategy = test
target = emotion
[FEATS]
type = ['os']
; type = ['hubert-xlarge-ll60k']
; no_reuse = False
; scale = standard
[MODEL]
type = bayes
; save = True
[RESAMPLE]
replace = True
36 changes: 36 additions & 0 deletions tests/exp_polish_gmm.ini
@@ -0,0 +1,36 @@
[EXP]
root = /tmp/results/
name = exp_polish_os
[DATA]
databases = ['train', 'dev', 'test']
train = ./data/polish/polish_train.csv
train.type = csv
train.absolute_path = False
train.split_strategy = train
; train.audio_path = ./POLISH
dev = ./data/polish/polish_dev.csv
dev.type = csv
dev.absolute_path = False
dev.split_strategy = train
; dev.audio_path = ./POLISH
test = ./data/polish/polish_test.csv
test.type = csv
test.absolute_path = False
test.split_strategy = test
target = emotion
labels = ['fear', 'anger', 'neutral']
no_reuse = True
[FEATS]
type = ['os']
no_reuse = True
scale = standard
[MODEL]
type = gmm
GMM_components = 3
GMM_covariance_type = diag
; save = True
[RESAMPLE]
replace = True
[PLOT]
; do not plot anything
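A short sketch of how this new GMM test configuration would be exercised, assuming the ./data/polish CSV lists referenced above exist locally:

```python
# Hedged sketch: run the GMM test config added in this PR.
import subprocess

subprocess.run(
    ["python3", "-m", "nkululeko.nkululeko", "--config", "tests/exp_polish_gmm.ini"],
    check=True,
)
```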

32 changes: 32 additions & 0 deletions tests/exp_polish_tree.ini
@@ -0,0 +1,32 @@
[EXP]
root = ./tests/results/
name = exp_polish_os
[DATA]
databases = ['train', 'dev', 'test']
train = ./data/polish/polish_train.csv
train.type = csv
train.absolute_path = False
train.split_strategy = train
; train.audio_path = ./POLISH
dev = ./data/polish/polish_dev.csv
dev.type = csv
dev.absolute_path = False
dev.split_strategy = train
; dev.audio_path = ./POLISH
test = ./data/polish/polish_test.csv
test.type = csv
test.absolute_path = False
test.split_strategy = test
target = emotion
no_reuse = True
[FEATS]
type = ['os']
; set = eGeMAPSv02
; level = lld
no_reuse = True
; scale = standard
[MODEL]
type = tree
; save = True
[RESAMPLE]
replace = True
30 changes: 30 additions & 0 deletions tests/exp_ravdess_speaker.ini
@@ -0,0 +1,30 @@
[EXP]
root = ./tests/results/
name = exp_ravdess_speaker
runs = 1
epochs = 1
save = True
[DATA]
databases = ['train', 'test']
train = ./data/ravdess/ravdess_speaker_train.csv
train.type = csv
train.absolute_path = False
train.split_strategy = train
test = ./data/ravdess/ravdess_speaker_test.csv
test.type = csv
test.absolute_path = False
test.split_strategy = test
target = speaker

labels = ['spk01', 'spk02', 'spk03', 'spk04', 'spk05', 'spk06', 'spk07', 'spk08', 'spk09', 'spk10', 'spk11', 'spk12', 'spk13', 'spk14', 'spk15', 'spk16', 'spk17', 'spk18', 'spk19', 'spk20', 'spk21', 'spk22', 'spk23', 'spk24']

[FEATS]
type = ['spkrec-ecapa-voxceleb']
no_reuse = False
scale = standard
[MODEL]
type = svm
C_val = 1.0
[RESAMPLE]
replace = True
sample_selection = all