-
Notifications
You must be signed in to change notification settings - Fork 0
/
an_compare_all.py
60 lines (42 loc) · 1.83 KB
/
an_compare_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import gaussian_filter1d

from utils import find_real_drift
# Experimental grid. Each list is iterated below to index the matching axis of
# the stored results and to lay out the subplot grid (rows = n_drifts,
# columns = n_features).
n_features = [10, 20, 30]
n_drifts = [5, 7, 9, 11]
methods = ['GNB', 'MCS-GNB', 'MLP', 'MCS-MLP', 'HTC', 'HTC-MCS']
# Per-method line styling, aligned with `methods`: colour pairs each base
# classifier with its counterpart; solid lines mark the MCS variant,
# dotted lines the plain one.
cols = ['r', 'r', 'g', 'g', 'b', 'b']
lss = [':', '-', ':', '-', ':', '-']

res_clf = np.load('results_v4/res_compare_all.npy')
print(res_clf.shape)
# Average over axis 2. The plotting code indexes
# mean_res[feature_idx, drift_idx, method_idx] and plots what remains as a
# single 1-D curve, so mean_res must be 4-D. That holds if res_clf is laid out
# as (features, n_drifts, reps, methods, chunks-1).
# NOTE(review): an earlier comment described res_clf as
# (features, n_drifts, drift_types, reps, methods, chunks-1); with that layout
# axis=2 would average drift types rather than repetitions - confirm against
# the script that writes res_compare_all.npy.
mean_res = np.mean(res_clf, axis=2)
print(mean_res.shape)
if mean_res.ndim != 4:
    # Fail loudly: with any other rank the per-method lookup below yields a
    # 2-D slice and would silently draw many unrelated lines per method.
    raise ValueError(
        'expected 4-D averaged results (features, n_drifts, methods, '
        'chunks-1), got shape %s' % (mean_res.shape,))
# One panel per (number of drifts, number of features) pair. squeeze=False
# keeps `axx` 2-D even if either list has a single entry.
fig, axx = plt.subplots(len(n_drifts), len(n_features), figsize=(10, 8),
                        sharey=True, squeeze=False)
for n_f_id, n_f in enumerate(n_features):
    for d_id, d in enumerate(n_drifts):
        # NOTE(review): find_real_drift lives in utils (not visible here);
        # presumably it returns the x positions of the d drifts for a stream
        # of length 500 - confirm.
        drifts = find_real_drift(500, d)
        ax = axx[d_id, n_f_id]
        if d_id == 0:
            ax.set_title('%i features' % n_f, fontsize=12)
        for m_id, m in enumerate(methods):
            # Smooth the averaged curve with a Gaussian kernel (sigma=1).
            temp = gaussian_filter1d(mean_res[n_f_id, d_id, m_id], 1)
            ax.plot(temp, label=m, c=cols[m_id], ls=lss[m_id])
        # Put a tick at every drift position, labelled 1, 2, ..., len(drifts).
        ax.set_xticks(drifts, np.arange(1, len(drifts) + 1))
        if d_id == len(n_drifts) - 1:
            ax.set_xlabel('index of drift', fontsize=12)
        if n_f_id == 0:
            ax.set_ylabel('%i drifts \n $accuracy$' % d, fontsize=12)
        ax.grid(ls=':')
        # Shrink the panel height by 10% from the bottom.
        # NOTE(review): plt.tight_layout() later in the script recomputes
        # subplot positions, which likely overrides this and the
        # subplots_adjust call below - confirm whether they are still needed.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
                         box.width, box.height * 0.9])
# A single shared legend, hung below the bottom-centre panel.
axx[-1, len(n_features) // 2].legend(loc='upper center',
                                     bbox_to_anchor=(0.5, -0.35),
                                     frameon=False, ncol=3)
plt.subplots_adjust(left=0.07, right=0.93, wspace=-0.35, hspace=0.05)
# Remove the top and right frame lines from every panel.
for panel in axx.flat:
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)
plt.tight_layout()
# savefig does not create missing directories; make sure the output folder
# exists so the PNG/EPS exports below do not fail on a fresh checkout.
os.makedirs('figures', exist_ok=True)
# NOTE(review): 'foo.png' looks like a leftover scratch output in the working
# directory - consider dropping it.
plt.savefig('foo.png')
# NOTE(review): 'comare' looks like a typo for 'compare'; the name is kept
# unchanged because other documents may already reference these files.
plt.savefig('figures/comare_all.png')
plt.savefig('figures/comare_all.eps')