-
Notifications
You must be signed in to change notification settings - Fork 0
/
BNetwork.py
31 lines (22 loc) · 865 Bytes
/
BNetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pomegranate as pom
import pandas as pd
import pickle
# Input data and model locations used by this script.
CSV_PATH = "Electricity_P_Thinned_Hourly_MeanImp.csv"
MODEL_PATH = "BNet.json"
# All columns listed for the CSV; kept for reference. Only FEATURES is used
# for training.
ALL_COLUMNS = ['UNIX_TS', 'CWE', 'DWE', 'FRE', 'HPE', 'WOE',
               'CDE', 'EBE', 'FGE', 'HTE', 'TVE']
# Columns the network is trained on.
FEATURES = ['UNIX_TS', 'TVE']


def train_and_save(csv_path=CSV_PATH, model_path=MODEL_PATH, features=FEATURES):
    """Learn a Bayesian network from a CSV file and write it to disk as JSON.

    This is not called when the script runs; call it explicitly to retrain
    and regenerate the JSON file that the script loads below.

    Args:
        csv_path: CSV file read with ``pd.read_csv``.
        model_path: destination file for ``model.to_json()``.
        features: column names to train on. They are also passed as the
            network's ``state_names``, so each node is named after the column
            it models.

    Returns:
        The trained and baked ``pom.BayesianNetwork``.
    """
    features = list(features)
    print(pom.utils.is_gpu_enabled())
    df = pd.read_csv(csv_path)
    df_tr = df.loc[:, features]
    print("Starting Training")
    # state_names must name exactly the selected columns, in order; passing
    # the full 11-name column list for a 2-column frame mislabels nodes.
    net = pom.BayesianNetwork.from_samples(
        df_tr, algorithm='exact', state_names=features)
    net.bake()
    # `with` guarantees the file is closed even if the write raises.
    with open(model_path, 'w') as fh:
        fh.write(net.to_json())
    return net


# Load the previously trained network and report its marginals and size.
model = pom.BayesianNetwork.from_json(MODEL_PATH)
# Example query (disabled): print(model.probability(['861.0', None]))
print(model.marginal())
print(model.state_count())