Skip to content

Commit

Permalink
Find the best mmfs for distributions mixing using Kullback-Leibler di…
Browse files Browse the repository at this point in the history
…vergence
  • Loading branch information
LisIva committed May 22, 2024
1 parent 0a4ea75 commit 7b6d0bc
Showing 1 changed file with 158 additions and 38 deletions.
196 changes: 158 additions & 38 deletions smoothing_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib as mpl

# from matplotlib import colormaps

def get_sf(rf):
std_factor = 100 * np.std(csym_arr) - rf * np.max(csym_arr)
Expand All @@ -29,42 +29,162 @@ def add_array():
return new_array


# Demo script: smooth one coefficient vector towards a uniform vector and
# draw the resulting probabilities as a bar chart.
#
# Alternative inputs tried earlier, kept for reference:
# csym_arr = np.array([200, 250, 20, 30, 0.01])
# csym_arr = np.array([1, 20, 15, 30, 1])
# csym_arr = np.array([10, 20, 15, 30, 23])
csym_arr = np.array([1.2, 0.3, 2.1, 0.7, 0.8, 0.5, 0.1, 2.0, 1.0, 0.15, 0.23])
# csym_arr = np.array([-0.5166702263242204, 1e-06, 0.30847567753139465, -0.0021320685285910806, -0.013484108957650603, 1e-06, 0.0004800419368503562, 0.0005433419101957565, 0.0033599844393899943, 1e-06, 1e-06, 1e-06, -0.00027104834002254167, -0.001684295171560161, 1.1097813793237595e-05, 1e-06, 1e-06, 1e-06, -4.3438843718743564e-07, 2.1009290401237286e-06, -3.2298730017636938e-06, 1e-06, 1e-06, 1e-06, 1.6266320830728536e-06])

# Only magnitudes matter from here on; signs are discarded.
csym_arr = np.fabs(csym_arr)
maxi = np.max(csym_arr)
mini = np.min(csym_arr)

# NOTE(review): "mmf" is presumably a "min-max factor" - confirm with the author.
mmf = 2.4
min_max_coeff = maxi - mmf * mini
# Weight of the uniform vector in the blend below.
smoothing_factor = min_max_coeff / (min_max_coeff + (mmf - 1) * np.average(csym_arr))

# Constant vector holding the mean of csym_arr in every position.
uniform_csym = np.full(len(csym_arr), np.sum(csym_arr) / len(csym_arr))
smoothed_array = (1 - smoothing_factor) * csym_arr + smoothing_factor * uniform_csym
final_probabilities = smoothed_array / np.sum(smoothed_array)
simple_aver = csym_arr / np.sum(csym_arr)
no_normal = smoothed_array

# plt.get_cmap replaces mpl.cm.get_cmap, which newer Matplotlib releases no
# longer provide.
cmap = plt.get_cmap('cividis')
# Pad the colour range by 20% on both ends so no bar sits at an extreme colour.
norm = matplotlib.colors.Normalize(vmin=mini - 0.2 * mini, vmax=maxi + 0.2 * maxi)
colors = cmap(norm(csym_arr))

# fig, ax = plt.subplots(figsize=(20, 20))
fig, ax = plt.subplots(figsize=(16, 8))
ax.set_ylim(0, np.max(final_probabilities) + 0.01)
sns.barplot(x=np.arange(len(csym_arr)), y=final_probabilities, orient="v", ax=ax, palette=colors)
ax.set_yticks([0., 1.])
ax.set_yticklabels([0, 1])

plt.yticks(fontsize=200)
plt.grid()
# An earlier title (np.std(smoothed_array) / np.max(smoothed_array) * 100) was
# set here and immediately overwritten by the call below, so it was dropped.
plt.title(f"Smoothing factor: {smoothing_factor:.3f}")
plt.show()
# plt.savefig('civi_norm.png', transparent=True)
def calc_prob(csym_arr, mmf=3.0):
    """Turn coefficient values into a smoothed probability vector.

    The absolute values of ``csym_arr`` are blended with a constant vector
    holding their mean; the blend weight (smoothing factor) is derived from
    the maximum, minimum and mean magnitude together with ``mmf``. The
    blended vector is then divided by its sum.

    Args:
        csym_arr: Sequence or array of coefficients; only magnitudes are used.
        mmf: Multiplier applied to the minimum magnitude when computing the
            smoothing factor (presumably a "min-max factor" - confirm).

    Returns:
        ``numpy.ndarray`` with the shape of the input whose entries sum to 1.

    Raises:
        ValueError: If ``csym_arr`` is empty.
    """
    magnitudes = np.fabs(csym_arr)
    if magnitudes.size == 0:
        raise ValueError("csym_arr must contain at least one value")

    total = np.sum(magnitudes)
    if total == 0:
        # All-zero input carries no preference: return the uniform vector
        # instead of the NaNs produced by 0 / 0.
        return np.full(magnitudes.shape, 1.0 / magnitudes.size)

    mean = total / magnitudes.size
    min_max_coeff = np.max(magnitudes) - mmf * np.min(magnitudes)
    denominator = min_max_coeff + (mmf - 1) * mean
    # A zero denominator (e.g. constant input with mmf == 1) would give NaN;
    # fall back to no smoothing in that case.
    smoothing_factor = min_max_coeff / denominator if denominator != 0 else 0.0

    # Broadcasting the scalar mean is equivalent to blending with a uniform
    # vector filled with the mean.
    smoothed = (1 - smoothing_factor) * magnitudes + smoothing_factor * mean
    return smoothed / np.sum(smoothed)


# ['Pastel1', 'Pastel2', 'Paired', 'Accent', 'Dark2', 'Set1', 'Set2', 'Set3', 'tab10', 'tab20', 'tab20b', 'tab20c'])
def plot(csym_arr):
    """Show a bar chart of ``csym_arr`` divided by its sum.

    Bars are coloured from the 'tab20' colormap according to each value's
    position in a range padded by 20% below the minimum and above the
    maximum of the data.

    Args:
        csym_arr: Sequence or array of numeric values to plot.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Accept plain lists as well as arrays (list / float is a TypeError).
    values = np.asarray(csym_arr, dtype=float)
    lowest = np.min(values)
    highest = np.max(values)

    # plt.get_cmap replaces mpl.cm.get_cmap, which newer Matplotlib releases
    # no longer provide.
    cmap = plt.get_cmap('tab20')
    norm = matplotlib.colors.Normalize(vmin=lowest - 0.2 * lowest,
                                       vmax=highest + 0.2 * highest)
    colors = cmap(norm(values))

    fig, ax = plt.subplots(figsize=(16, 8))
    sns.barplot(x=np.arange(len(values)), y=values / np.sum(values), orient="v", ax=ax, palette=colors)
    # Optional axis styling used in earlier experiments:
    # ax.set_yticks([0., 1.])
    # ax.set_yticklabels([0, 1])
    # plt.yticks(fontsize=200)
    plt.grid()
    plt.show()
    # plt.savefig('civi_norm.png', transparent=True)


def calc_kl(symn_distr, ideal_distr):
    """Return the Kullback-Leibler divergence of ``ideal_distr`` from ``symn_distr``.

    Computes ``sum(ideal * log(ideal / symn))`` with the natural logarithm.
    Entries where ``ideal_distr`` is zero contribute nothing (the usual
    ``0 * log(0) = 0`` convention) instead of turning the result into NaN.

    Args:
        symn_distr: Approximating distribution (sequence or array).
        ideal_distr: Reference distribution (sequence or array).

    Returns:
        The divergence as a NumPy float.
    """
    reference, approx = np.broadcast_arrays(
        np.asarray(ideal_distr, dtype=float),
        np.asarray(symn_distr, dtype=float),
    )
    # Restrict the sum to the support of the reference distribution.
    support = reference > 0
    return np.sum(reference[support] * np.log(reference[support] / approx[support]))


################ kdv_sindy ################ ideal 3.0
# csym_ideal = [0.3, 1.5, 0.3, 0.3, 1.5, 0.3, 1.5, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]

# 0%: 1.4
# csym_ls = [0.00013317180943934418, 0.462545700475963, -0.8913131300529831, 0.0013318871769012945, 0.033778270898942905, 1e-06, 0.0003630221131877745, -0.0003721569836748165, -0.0008534115707586299, 1e-06, 1e-06, 1e-06, -0.00028500296474465853, -0.00019525685837003688, -0.0005890799551753088]

# 25% 1.4
# csym_ls = [-0.0001293801315590306, 0.46248752519025266, -0.8914188507595079, -0.00029770137936806446, 0.03355619962099748, 1e-06, 0.00011641090470355867, 0.00021790578562142643, 4.396431886170122e-05, 1e-06, 1e-06, 1e-06, 0.0005026998071588287, -0.00025276791534035244, 0.00010830538102359453]# 50% 1.4

# 50%: 1.4
# csym_ls = [0.0002381475418994724, 0.46226309385704617, -0.8904106491977929, 0.001054075960887105, 0.03411553851629949, 1e-06, -0.0015553574670286198, -0.0010058326244412636, 0.0003338272190270323, 1e-06, 1e-06, 1e-06, -0.0026314026825166427, 0.0006456291858422701, 0.00038857257310549736]

# 75%: 3.3 kl! - 0.0459
# csym_ls = [0.00026640691887069206, 0.46271247307315194, -0.8912386944442275, -0.0006158038852539438, 0.0341862517020763, 1e-06, 0.0003595406815783831, 0.0018753700304558968, -0.0004868756828956544, 1e-06, 1e-06, 1e-06, -0.0006377581471452651, -0.0005412319774752918, 0.0027144138766298782]

# 100%:
# csym_ls = [-0.0003243016920022109, 0.46284821240263196, -0.8912387867895476, 0.0007263872116419843, 0.034457638015716316, 1e-06, 0.00023488650716849385, -0.0010765213036412142, -0.0019783785864441765, 1e-06, 1e-06, 1e-06, -0.001446893439064628, -0.00040456822956662386, 0.0008374295628705653]


######################## kdv ######################## ideal 3.5
# csym_ideal = [0.3, 1.5, 0.3, 0.3, 1.5, 1.5, 0.3, 1.5, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
# 0%: 4.7
# csym_ls = [-0.5738747299324107, 1.9409932356425736, -0.11654156765055264, -0.07046287746560881, -1.0212144868330735, 1.032042117248668, 1e-06, -2.3497580511626914, -0.2794104186792921, 0.11785213646185881, 0.029564732783981536, 1e-06, 1e-06, 1e-06, 1e-06, 0.3492224377043156, -0.13732691469459216, -0.05736074929210076, 0.04375671254047099, -0.07784762367293639, 0.037625163196757835]

# 25%: 4.6
# csym_ls = [-0.6226895519635444, 2.059189242411157, -0.25990141725526117, -0.09255353453936892, -1.0308850723806922, 1.0263609752377283, 1e-06, -2.362048628903625, -0.2347217544911788, 0.2198358341268513, -0.06935756617383874, 1e-06, 1e-06, 1e-06, 1e-06, 0.37563869354288004, -0.3301659607478368, 0.08413687396354587, -0.02749269514926165, 0.017197084512367556, -0.009970393980948858]

# 50%: 4.5
# csym_ls = [-0.46078477462054285, 2.2704157258087294, -0.02254640500282605, -0.05283098761069975, -1.0144895824220672, 1.0389887384148027, 1e-06, -2.6785248498562333, -0.25115178191156073, 0.039946912853136826, 0.04755320454356934, 1e-06, 1e-06, 1e-06, 1e-06, 0.5771951788539833, -0.1366054267518474, 0.04043457649866549, 0.03220780678502397, -0.07806429852643527, 0.024531951361348142]

# 75%: 4.5
# csym_ls = [-0.5517790443652698, 1.8288663300338377, 0.08831887363598773, -0.008036031293322934, -1.0277282517623498, 1.0600779063464922, 1e-06, -2.107357186490851, -0.2090304907493231, -0.007240772674735302, 0.23531775085988, 1e-06, 1e-06, 1e-06, 1e-06, 0.4073260831687099, -0.025907989028107578, -0.2591269348813302, 0.06002539421718861, -0.18718574744051725, 0.05897196698067864]

# 100%: 4.1
# csym_ls = [-0.611807256334479, 2.091573573749484, -0.03502775059760545, -0.02471404489900008, -1.028565097167191, 1.0388048999412942, 1e-06, -2.1819574719150188, -0.33605326307870365, 0.2055232524241674, -0.011441006498573061, 1e-06, 1e-06, 1e-06, 1e-06, 0.6929180258977231, -0.33289892936537885, 0.05173891202131046, -0.07169107134204616, 0.0011062010753618307, -0.008829736762602198]
# kl < 0.05

######################## burg_s ######################## ideal 3.0
# csym_ideal = [0.3, 1.5, 0.3, 1.5, 0.3, 1.5, 0.3, 0.3, 0.3, 0.3]
# 0%: 3.2
# csym_ls = [0.005460569789379852, 0.5387707061029996, -0.0024549705290218106, 0.09444975430284994, 1e-06, -0.9830916579031492, -0.012947013718093547, 1e-06, 1e-06, -0.03281432540987663]

# 25%: 3.2
# csym_ls = [0.03787527904294921, 0.4939875079417061, -0.023765529791406115, 0.07811680204532033, 1e-06, -0.8524158052588215, -0.04208106982942887, 1e-06, 1e-06, -0.05093147579808352]

# 50%: 3.1
# csym_ls = [0.05352525022936795, 0.4925431732500289, -0.03360749634996277, 0.06837170520432455, 1e-06, -0.7638492542123676, -0.08414698392229335, 1e-06, 1e-06, -0.06366918480163249]

# 75%: 2.8
# csym_ls = [0.016178502189279916, 0.4309538468067937, -0.032884557235337744, 0.059976475495745904, 1e-06, -0.5933034380209061, -0.17749882997856425, 1e-06, 1e-06, -0.15215799047086387]

# 100%: 2.8
# csym_ls = [0.07419864346609416, 0.4573251475230543, -0.06614196245044304, 0.04445120113138302, 1e-06, -0.5975513976960619, -0.07755442017334822, 1e-06, 1e-06, -0.05575691426237836]
# kl < 0.07

######################## burgers ######################## ideal 3.0
# csym_ideal = [0.3, 1.5, 0.3, 0.3, 1.5, 0.3]
# 0%: 3.5
# csym_ls = [-0.00349112838018091, 0.5740381641888227, -0.14501829929682208, 1e-06, -1.0030580290808233, 1e-06]

# 25%: 3.4
# csym_ls = [-0.006532582717331897, 0.6508152205067945, -0.2801774548165729, 1e-06, -1.021452986197016, 1e-06]

# 50%: 3.4
# csym_ls = [2.2340578143178005e-05, 0.6768692093225501, 0.14415629801550106, 1e-06, -1.0025505969774686, 1e-06]

# 75%: 2.6 - 3.6
# csym_ls = [-0.012041642402178008, 1.1473938732374185, -0.9184598254725471, 1e-06, -1.0292540103892314, 1e-06]
# csym_ls = [0.030731734503859236, 0.5378339975533124, -0.14870655029529314, 1e-06, -0.9099623168305407, 1e-06]
# csym_ls = [-0.008590791125349218, 0.5649451658146871, 0.10302964293542353, 1e-06, -1.0268606886939506, 1e-06]

# 100%: 3.5
# csym_ls = [-0.0019257019095356431, 0.5493267689398174, 0.09538386915583418, 1e-06, -1.0032696687238005, 1e-06]



######################## wave ######################## ideal 3.0
csym_ideal = [0.3, 0.3, 1.5, 0.3, 1.5]
# 0%: 1.1
# csym_ls = [-0.287399004879559, -0.16320653549114902, 0.30994707053525344, -0.038707991119449836, 0.02758466904541761]\

# 25%: 1.0
# csym_ls = [-0.43033962715279256, 0.6488653494822086, 0.5533729523848405, -0.004534756559710221, 0.027650062187502576]

# 50%:
# csym_ls = [1.1442625147834207, 0.7179028644857391, 0.024608024185436127, 0.0006107594501504804, -0.0026897761639400903]

# 75%:
# csym_ls = [1.1316718856213246, 0.6863988037639855, -0.01910488566292144, -0.00012651579854263187, -0.0018741658849076542]

# 100%:
csym_ls = [1.3795505591347466, 0.771918351854544, 0.05970507409859573, -0.0008519145298678077, -0.0016470272978029558]

# Scan mmf over 1.0, 1.1, ..., 5.9 and record, for each value, the KL
# divergence between the distribution built from csym_ls and the reference
# distribution built from csym_ideal.
mmfs = [1.0 + 0.1 * incr for incr in range(50)]
kl_ls = []
# The reference distribution does not depend on the scanned mmf, so it is
# computed once instead of on every iteration.
id_distr = calc_prob(csym_ideal, 3.0)
for mmf in mmfs:
    sym_distr = calc_prob(csym_ls, mmf)
    val = calc_kl(sym_distr, id_distr)
    kl_ls.append(val)
# Maps each scanned mmf to its KL divergence.
dct = dict(zip(mmfs, kl_ls))
# NOTE(review): this prints only a blank line; presumably a breakpoint anchor
# for inspecting dct in a debugger - confirm before relying on the output.
print()

























0 comments on commit 7b6d0bc

Please sign in to comment.