Skip to content

Commit

Permalink
Update motivating_example.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 18, 2024
1 parent 6d9c6cf commit f7a33a2
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions notebooks/motivating_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,22 @@
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n",
"\n",
"from eliater.discover_latent_nodes import (\n",
" find_nuisance_variables,\n",
" mark_nuisance_variables_as_latent,\n",
" remove_nuisance_variables,\n",
")\n",
"from eliater.examples.frontdoor_backdoor_discrete import (\n",
" single_mediator_with_multiple_confounders_nuisances_discrete_example,\n",
")\n",
"from eliater.network_validation import print_graph_falsifications\n",
"from y0.algorithm.identify import Identification\n",
"from y0.dsl import P, Variable, X, Y\n",
"from y0.algorithm.estimation import estimate_ace\n",
"from y0.algorithm.identify import Identification, identify_outcomes\n",
"from eliater.discover_latent_nodes import find_nuisance_variables, mark_nuisance_variables_as_latent\n",
"from eliater.discover_latent_nodes import remove_nuisance_variables\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt"
"from y0.dsl import P, Variable, X, Y"
]
},
{
Expand Down Expand Up @@ -529,9 +531,7 @@
}
],
"source": [
"ATE_value = estimate_ace(\n",
" graph=new_graph, treatments=X, outcomes=Y, data=data\n",
")\n",
"ATE_value = estimate_ace(graph=new_graph, treatments=X, outcomes=Y, data=data)\n",
"ATE_value"
]
},
Expand Down Expand Up @@ -599,7 +599,7 @@
" num_samples=500, seed=500, treatments={X: 0}\n",
")\n",
"\n",
"true_ate = intv_data_X_1.mean()[\"Y\"] - intv_data_X_0.mean()[\"Y\"] \n",
"true_ate = intv_data_X_1.mean()[\"Y\"] - intv_data_X_0.mean()[\"Y\"]\n",
"print(f\"The true ATE is {true_ate:.04}\")"
]
},
Expand Down Expand Up @@ -631,16 +631,20 @@
"source": [
"background_ates = []\n",
"for _ in range(500):\n",
" intv_data_X_1 = single_mediator_with_multiple_confounders_nuisances_discrete_example.generate_data(\n",
" num_samples=500, treatments={X: 1}\n",
" intv_data_X_1 = (\n",
" single_mediator_with_multiple_confounders_nuisances_discrete_example.generate_data(\n",
" num_samples=500, treatments={X: 1}\n",
" )\n",
" )\n",
" intv_data_X_0 = single_mediator_with_multiple_confounders_nuisances_discrete_example.generate_data(\n",
" num_samples=500, treatments={X: 0}\n",
" intv_data_X_0 = (\n",
" single_mediator_with_multiple_confounders_nuisances_discrete_example.generate_data(\n",
" num_samples=500, treatments={X: 0}\n",
" )\n",
" )\n",
" background_ate = intv_data_X_1.mean()[\"Y\"] - intv_data_X_0.mean()[\"Y\"]\n",
" background_ates.append(background_ate)\n",
"\n",
"fix, ax = plt.subplots(1,1, figsize=(5, 2))\n",
"fix, ax = plt.subplots(1, 1, figsize=(5, 2))\n",
"sns.histplot(background_ates, ax=ax)\n",
"ax.axvline(true_ate, linewidth=3, color=\"red\")\n",
"ax.axvline(ATE_value, linewidth=3, color=\"green\")\n",
Expand Down

0 comments on commit f7a33a2

Please sign in to comment.