Skip to content

Commit

Permalink
Fix Plots
Browse files Browse the repository at this point in the history
- add missing xlabel
- change "Mean Difference" to "Mean Distance"
- reduce plotting data
 - plot max 1000000 random positions to reduce runtime and prevent crashing
  • Loading branch information
JannesSP committed Aug 14, 2023
1 parent 30962ad commit a9a19cf
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions magnipore/magnipore.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ def magnipore(mapping : dict, unaligned : dict, seq_dict : dict, aln_dict: dict,
num_indels, sign_pos, nans = 0, 0, 0

# TODO add some quality value
plotting_data = pd.DataFrame(columns=['Mean Difference', 'Avg Stdev', 'Strand', 'Mutational Context', 'Significant', 'TD Score', 'KL Divergence'])
plotting_data = pd.DataFrame(columns=['Mean Distance', 'Avg Stdev', 'Strand', 'Mutational Context', 'Significant', 'TD Score', 'KL Divergence'])
plotting_data = plotting_data.astype(
{
'Mean Difference': 'float32',
'Mean Distance': 'float32',
'Avg Stdev': 'float32',
'Strand': 'bool',
'Mutational Context': 'bool',
Expand Down Expand Up @@ -309,7 +309,7 @@ def magnipore(mapping : dict, unaligned : dict, seq_dict : dict, aln_dict: dict,
td = td_score(mDiff, sAvg)
significant = td>=1
new_entry = pd.DataFrame({
'Mean Difference' : [mDiff],
'Mean Distance' : [mDiff],
'Avg Stdev' : [sAvg],
'Strand' : [strand],
'Mutational Context' : [mut_context],
Expand Down Expand Up @@ -385,10 +385,15 @@ def plotStatistics(plotting_data : pd.DataFrame, working_dir : str, first_sample
plot_dir = os.path.join(working_dir, 'magnipore', f'{first_sample_label}_{sec_sample_label}', 'plots')
if not os.path.exists(plot_dir):
os.mkdir(plot_dir)
### Mean Dist vs Std Avg plot
# reduce plotting_data if it got too large, to reduce runtime and prevent the kernel from killing the process
plotting_threshold = 1000000 # arbitrary threshold
if len(plotting_data.index) > plotting_threshold:
LOGGER.printLog(f'The number of positions exceeds the threshold of {plotting_data} ({len(plotting_data.index)}). To prevent the kernel from killing the process, Magnipore will only plot a subset of {plotting_data} positions. Plots will not include the full data.')
plotting_data = plotting_data.sample(plotting_threshold, replace=False)
# Mean Dist vs Std Avg plot
LOGGER.printLog('Plotting Mean vs Stdev')
plotMeanDiffStdAvg(plotting_data, plot_dir, first_sample_label, sec_sample_label)
### plot scores
# plot scores
LOGGER.printLog(f'Plotting TD score and KL divergence')
plotScores(plotting_data, plot_dir, first_sample_label, sec_sample_label)

Expand Down Expand Up @@ -434,17 +439,17 @@ def plotMeanDiffStdAvg(dataframe : pd.DataFrame, working_dir : str, first_sample
label1 = first_sample_label.replace("_", " ")
label2 = sec_sample_label.replace("_", " ")

g = sns.JointGrid(x='Mean Difference', y='Avg Stdev', data=dataframe, hue='Mutational Context', marginal_ticks=True, palette=['blue', '#d95f02'], hue_order=[True, False], height = 10)
g = sns.JointGrid(x='Mean Distance', y='Avg Stdev', data=dataframe, hue='Mutational Context', marginal_ticks=True, palette=['blue', '#d95f02'], hue_order=[True, False], height = 10)
g.plot_joint(func=sns.scatterplot, s = 8)
g.ax_joint.cla()
for _, row in dataframe.iterrows():
g.ax_joint.plot(row['Mean Difference'], row['Avg Stdev'], color = color(row['Mutational Context']), marker = marker(row['Mutational Context']), markersize=3, alpha = 0.6)
g.ax_joint.plot(row['Mean Distance'], row['Avg Stdev'], color = color(row['Mutational Context']), marker = marker(row['Mutational Context']), markersize=3, alpha = 0.6)

g.fig.suptitle(f'{len(dataframe.index)} compared bases mean difference against\naverage standard deviation\n{label1} and {label2}', y=0.98)
g.fig.suptitle(f'{len(dataframe.index)} compared bases mean distance against\naverage standard deviation\n{label1} and {label2}', y=0.98)
g.ax_joint.grid(True, 'both', 'both', alpha = 0.4, linestyle = '--', linewidth = 0.5)

lims = np.array([
[-.02, max(dataframe['Mean Difference']) + 0.1],
[-.02, max(dataframe['Mean Distance']) + 0.1],
[-.02, max(dataframe['Avg Stdev']) + 0.1]
])

Expand All @@ -457,7 +462,8 @@ def plotMeanDiffStdAvg(dataframe : pd.DataFrame, working_dir : str, first_sample
g.ax_joint.set_xlim(tuple(lims[0]))
g.ax_joint.set_ylim(tuple(lims[1]))

g.ax_joint.set_ylabel('Average Standard Deviation')
g.ax_joint.set_ylabel('Average standard deviation')
g.ax_joint.set_xlabel('Mean distance')
legend_mut = mlines.Line2D([], [], color='blue', marker='D', linestyle='None', markersize=10, label='mutation')
legend_mod = mlines.Line2D([], [], color='#d95f02', marker='o', linestyle='None', markersize=10, label='matching reference')
sign = mlines.Line2D([], [], color='#1b9e77', marker='s', linestyle='None', markersize=10, label='significant, TD>=1')
Expand Down

0 comments on commit a9a19cf

Please sign in to comment.