Skip to content

Commit

Permalink
fix: empty score matrix breaks MatchMS
Browse files Browse the repository at this point in the history
  • Loading branch information
tharwood3 committed Apr 25, 2024
1 parent d5daa2f commit 90aef24
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions metatlas/plots/dill2plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import pandas as pd
import dill
import numpy as np
import numpy.typing as npt
import json
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -62,6 +63,7 @@
from io import StringIO

from matchms.similarity import CosineHungarian
from matchms.typing import SpectrumType
from matchms import Spectrum

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -2242,6 +2244,42 @@ def create_nonmatched_msms_hits(msms_data: pd.DataFrame, inchi_key: str) -> pd.D

return inchi_msms_hits


def score_matrix(cos: CosineHungarian, references: List[SpectrumType], queries: List[SpectrumType]) -> npt.NDArray:
    """
    Calculate matrix of scores and matching ion counts using MatchMS spectrum objects.

    This is a replacement for (and is derived from) the native MatchMS BaseSimilarity.matrix method.
    The source for that method is here:
    https://github.com/matchms/matchms/blob/master/matchms/similarity/BaseSimilarity.py
    This additional function is necessary to fix numpy errors that arise when the score matrix
    is filled with all 0 values.

    Args:
        cos: similarity object; its ``pair(reference, query)`` method is called once for every
            reference/query combination and must return one (score, matches) record.
        references: spectra that index the rows of the result.
        queries: spectra that index the columns of the result.

    Returns:
        Structured array of shape ``(len(references), len(queries))`` with the fields
        ``"score"`` (float) and ``"matches"`` (int). If either input is empty, an empty
        array of the corresponding shape is returned.
    """
    score_datatype = [("score", "float"), ("matches", "int")]

    n_rows = len(references)
    n_cols = len(queries)

    # Every pair is scored (no filtering), references outermost, so the flat
    # list is already in row-major order for an (n_rows, n_cols) matrix.
    scores = [cos.pair(reference, query) for reference in references for query in queries]

    # Reshaping instead of scattering through index arrays also keeps empty
    # inputs working: an empty index array would default to a float dtype,
    # which numpy rejects as an index.
    return np.array(scores, dtype=score_datatype).reshape(n_rows, n_cols)


def get_hits_per_compound(cos: Type[CosineHungarian], inchi_key: str,
msms_data: pd.DataFrame, msms_refs: pd.DataFrame) -> pd.DataFrame:
"""
Expand All @@ -2260,8 +2298,7 @@ def get_hits_per_compound(cos: Type[CosineHungarian], inchi_key: str,

filtered_msms_data = msms_data[msms_data['inchi_key']==inchi_key].reset_index(drop=True).drop(columns=['inchi_key', 'precursor_mz']).copy()

scores_matches = cos.matrix(filtered_msms_data.matchms_spectrum.tolist(),
filtered_msms_refs.matchms_spectrum.tolist())
scores_matches = score_matrix(cos, filtered_msms_data.matchms_spectrum.tolist(), filtered_msms_refs.matchms_spectrum.tolist())

inchi_msms_hits = pd.merge(filtered_msms_data, filtered_msms_refs.drop(columns=['name', 'adduct']), how='cross')
inchi_msms_hits['score'] = scores_matches['score'].flatten()
Expand Down

0 comments on commit 90aef24

Please sign in to comment.