Skip to content

Commit

Permalink
Merge pull request #1 from Howardhuang98/rl_method
Browse files Browse the repository at this point in the history
updated
  • Loading branch information
Howardhuang98 authored Nov 30, 2021
2 parents 700ef4d + 1460879 commit 2f18bd4
Show file tree
Hide file tree
Showing 20 changed files with 1,883 additions and 901 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ School: Tianjin University, Priceless Lab
## Here you can use:
* greedy hill climb
* dynamic program: shortest path perspective
* PC algorithm

## Project structure

Expand Down
9 changes: 9 additions & 0 deletions datasets/Asian expert.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
,smoke,bronc,lung,asia,tub,either,dysp,xray
smoke,0,0.9,0.9,0,0,0,0,0
bronc,0,0,0,0,0,0,0.9,0
lung,0,0,0,0,0,0.9,0,0
asia,0,0,0,0,0.9,0,0,0
tub,0,0,0,0,0,0.9,0,0
either,0,0,0,0,0,0,0.9,0.9
dysp,0,0,0,0,0,0,0,0
xray,0,0,0,0,0,0,0,0
92 changes: 92 additions & 0 deletions dlbn/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
base class
"""

from abc import ABC
from abc import abstractmethod

import networkx as nx
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt



class Estimator(ABC):

def load_data(self, data):
"""
加载数据
"""

if isinstance(data, pd.DataFrame):
self.data = data
elif isinstance(data, np.ndarray):
data = pd.DataFrame(data, columns=[range(data.shape[0])])
self.data = data
else:
raise ValueError("Data loading error")

def show_est(self):
print("=========Estimator Information=========")
print('''
·▄▄▄▄ ▄▄▌ ▄▄▄▄· ▐ ▄
██▪ ██ ██• ▐█ ▀█▪ •█▌▐█
▐█· ▐█▌ ██▪ ▐█▀▀█▄ ▐█▐▐▌
██. ██ ▐█▌▐▌ ██▄▪▐█ ██▐█▌
▀▀▀▀▀• .▀▀▀ ·▀▀▀▀ ▀▀ █▪
''')
print(self.data.head(5))
print("Recover the BN with {} variables".format(len(self.data.columns)))

@abstractmethod
def run(self):
"""
run the estimator
"""

def show(self,):
if self.result_dag:
plt.figure()
nx.draw_networkx(self.result_dag)
plt.title("Bayesian network")
plt.show()
else:
raise ValueError("No result obtained")


class Score(ABC):
"""
Score base class
"""

def __init__(self, data: pd.DataFrame):
self.data = data
self.state_names = {}
for var in list(data.columns.values):
self.state_names[var] = sorted(list(self.data.loc[:, var].unique()))
self.contingency_table = None

@abstractmethod
def local_score(self, x, parents):
"""
return local score
"""

def all_score(self, dag, detail=True):
"""
return score on the DAG
"""
score_dict = {}
score_list = []
for node in dag.nodes:
parents = list(dag.predecessors(node))
local_score = self.local_score(node, parents)
score_list.append(local_score)
if detail:
score_dict[node] = local_score
if detail:
return sum(score_list), score_dict
return sum(score_list)


103 changes: 0 additions & 103 deletions dlbn/direct_graph.py

This file was deleted.

Loading

0 comments on commit 2f18bd4

Please sign in to comment.