
Commit aabb1a8

linting fixes

VSumanth99 committed Jun 17, 2024
1 parent 64b6238 commit aabb1a8
Showing 43 changed files with 919 additions and 843 deletions.
2 changes: 1 addition & 1 deletion configs/synthetic/mcd.yaml
@@ -1,6 +1,6 @@
# hyperparameters
watch_gradients: false
-num_epochs: 1
+num_epochs: 10000
model: mcd
monitor_checkpoint_based_on: likelihood

32 changes: 9 additions & 23 deletions src/baselines/BaselineTrainer.py
@@ -1,13 +1,9 @@

from typing import Any
import lightning.pytorch as pl
-import torch.nn as nn
import torch
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader
import torch.nn.functional as F
-import numpy as np

from src.dataset.BaselineTSDataset import BaselineTSDataset
+import numpy as np
from src.utils.metrics_utils import mape_loss

class BaselineTrainer(pl.LightningModule):
@@ -21,7 +17,6 @@ def __init__(self,
lag: int,
num_workers: int = 16,
aggregated_graph: bool = False):

super().__init__()

self.num_workers = num_workers
@@ -35,29 +30,20 @@ def __init__(self,
assert adj_matrices.shape[0] == self.total_samples

self.full_dataset = BaselineTSDataset(
X = self.full_dataset_np,
adj_matrix = self.adj_matrices_np,
lag=lag,
aggregated_graph=self.aggregated_graph,
return_graph_indices=True
)

self.batch_size = 1

def forward(self):
    raise NotImplementedError

-def compute_mse(self, X_current, X_pred):
-    return F.mse_loss(X_current, X_pred)
+def compute_mse(self, x_current, x_pred):
+    return F.mse_loss(x_current, x_pred)

-def compute_mape(self, X_current, X_pred):
-    return mape_loss(X_current, X_pred)
+def compute_mape(self, x_current, x_pred):
+    return mape_loss(x_current, x_pred)

def training_step(self, batch, batch_idx):
    raise NotImplementedError

def get_full_dataloader(self) -> DataLoader:
return DataLoader(self.full_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
raise NotImplementedError
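
For orientation, a minimal sketch of how a concrete baseline plugs into BaselineTrainer: subclasses override predict_step, and evaluation runs through Lightning's predict loop over get_full_dataloader(). The toy subclass, the array shapes, and the use of self.lag below are illustrative assumptions, not part of this commit.

import numpy as np
import torch
import lightning.pytorch as pl
from src.baselines.BaselineTrainer import BaselineTrainer

class EmptyGraphBaseline(BaselineTrainer):
    # Toy baseline: predicts an edge-free graph for every sample.
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        X, adj_matrix, _ = batch
        batch_size, _, num_nodes, _ = X.shape
        # assumes the parent stores `lag` as self.lag
        graphs = torch.zeros(batch_size, self.lag + 1, num_nodes, num_nodes)
        return graphs, graphs, adj_matrix

X = np.random.randn(8, 50, 5, 1)   # [samples, timesteps, nodes, data_dim] (assumed layout)
A = np.zeros((8, 3, 5, 5))         # [samples, lag+1, nodes, nodes] (assumed layout)
model = EmptyGraphBaseline(full_dataset=X, adj_matrices=A, data_dim=1, lag=2, num_workers=0)
preds = pl.Trainer(logger=False).predict(model, model.get_full_dataloader())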
20 changes: 7 additions & 13 deletions src/baselines/DYNOTEARSTrainer.py
@@ -1,14 +1,15 @@
from typing import Any
-from src.baselines.BaselineTrainer import BaselineTrainer
import numpy as np
import torch

-from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_np, zero_out_diag_np, zero_out_diag_torch
-# import tigramite for pcmci
-from src.modules.dynotears.dynotears import from_pandas_dynamic
-import networkx as nx
import pandas as pd

+from src.baselines.BaselineTrainer import BaselineTrainer
+from src.modules.dynotears.dynotears import from_pandas_dynamic
+from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_np, zero_out_diag_np

class DYNOTEARSTrainer(BaselineTrainer):

def __init__(self,
@@ -43,7 +44,6 @@ def __init__(self,
num_workers=num_workers,
aggregated_graph=aggregated_graph)


self.max_iter = max_iter
self.lambda_w = lambda_w
self.lambda_a = lambda_a
@@ -58,7 +58,7 @@ def __init__(self,
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
X, adj_matrix, graph_index = batch

-batch_size, timesteps, num_nodes, data_dim = X.shape
+batch_size, timesteps, num_nodes, _ = X.shape
assert num_nodes == self.num_nodes
X = X.view(batch_size, timesteps, -1)

@@ -68,10 +68,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:

graphs = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes))
scores = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes))

if self.group_by_graph:
    n_unique_matrices = np.max(graph_index)+1
-    unique_matrices = np.unique(adj_matrix, axis=0)
else:
graph_index = np.zeros((batch_size))
n_unique_matrices = 1
@@ -81,7 +79,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
n_samples = np.sum(graph_index == i)
for x in X[graph_index == i]:
X_list.append(pd.DataFrame(x))

learner = from_pandas_dynamic(
X_list,
p=self.lag,
@@ -102,16 +99,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
# scores = np.hstack(temporal_adj_list)
temporal_adj = [(score != 0).astype(int) for _ in range(n_samples)]
score = [np.abs(score) for _ in range(n_samples)]

graphs[i == graph_index] = np.array(temporal_adj)
scores[i == graph_index] = np.array(score)


if self.aggregated_graph:
graphs = to_time_aggregated_graph_np(graphs)
scores = to_time_aggregated_scores_np(scores)
if self.ignore_self_connections:
graphs = zero_out_diag_np(graphs)
scores = zero_out_diag_np(scores)
return torch.Tensor(graphs), torch.abs(torch.Tensor(scores)), torch.Tensor(adj_matrix)

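For reference, a hedged standalone sketch of the from_pandas_dynamic call used above. It assumes the vendored module keeps the causalnex dynotears interface (one DataFrame per series, nodes named like "0_lag1", edge weights holding the learned coefficients); treat the keyword arguments as illustrative.

import numpy as np
import pandas as pd
from src.modules.dynotears.dynotears import from_pandas_dynamic

X_list = [pd.DataFrame(np.random.randn(200, 3))]   # one [timesteps, variables] frame per series
sm = from_pandas_dynamic(X_list, p=2, lambda_w=0.1, lambda_a=0.1, max_iter=100)

# Edge weights are the learned coefficients; predict_step binarizes them into
# `temporal_adj` and keeps their absolute values as `score`.
for u, v, w in sm.edges(data="weight"):
    print(u, v, w)
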
30 changes: 10 additions & 20 deletions src/baselines/PCMCITrainer.py
@@ -1,16 +1,17 @@
from typing import Any
-from src.baselines.BaselineTrainer import BaselineTrainer
+from copy import deepcopy

import numpy as np
-from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, zero_out_diag_np, zero_out_diag_torch
-# import tigramite for pcmci
-import tigramite
from tigramite import data_processing as pp
from tigramite.pcmci import PCMCI
from tigramite.independence_tests.parcorr import ParCorr
from tigramite.independence_tests.cmiknn import CMIknn
-from copy import deepcopy
import torch
-from src.utils.causality_utils import *

+from src.utils.causality_utils import convert_temporal_to_static_adjacency_matrix, cpdag2dags
+from src.baselines.BaselineTrainer import BaselineTrainer
+from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, zero_out_diag_np

"""
Large parts adapted from https://github.com/microsoft/causica
@@ -33,7 +34,6 @@ def __init__(self,
group_by_graph: bool = False,
ignore_self_connections: bool = False
):

self.group_by_graph = group_by_graph
self.ignore_self_connections = ignore_self_connections
if self.group_by_graph:
@@ -97,8 +97,7 @@ def _process_adj_matrix(self, adj_matrix: np.ndarray) -> np.ndarray:
def _run_pcmci(self, pcmci, tau_max, pc_alpha):
if self.pcmci_plus:
return pcmci.run_pcmciplus(tau_max=tau_max, pc_alpha=pc_alpha)
-else:
-    return pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha)
+return pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha)

def _process_cpdag(self, adj_matrix: np.ndarray):
"""
@@ -109,7 +108,6 @@ def _process_cpdag(self, adj_matrix: np.ndarray):
Returns:
adj_matrix: np.ndarray with shape [num_possible_dags, lag+1, num_nodes, num_nodes]
"""

lag_plus, num_nodes = adj_matrix.shape[0], adj_matrix.shape[1]
static_temporal_graph = convert_temporal_to_static_adjacency_matrix(
adj_matrix, conversion_type="auto_regressive"
@@ -130,22 +128,17 @@ def _process_cpdag(self, adj_matrix: np.ndarray):

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
X, adj_matrix, graph_index = batch

-batch_size, timesteps, num_nodes, data_dim = X.shape
+batch_size, timesteps, num_nodes, _ = X.shape
assert num_nodes == self.num_nodes

X = X.view(batch_size, timesteps, -1)
X, adj_matrix, graph_index = X.numpy(), adj_matrix.numpy(), graph_index.numpy()

graphs = [] #np.zeros((batch_size, self.lag+1, num_nodes, num_nodes))
new_adj_matrix = []
if self.group_by_graph:
n_unique_matrices = np.max(graph_index)+1
else:
graph_index = np.zeros((batch_size))
n_unique_matrices = 1

-unique_matrices = np.unique(adj_matrix, axis=0)
for i in range(n_unique_matrices):
print(f"{i}/{n_unique_matrices}")
n_samples = np.sum(graph_index == i)
@@ -156,13 +149,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
verbosity=0)

results = self._run_pcmci(pcmci, self.lag, self.pc_alpha)

graph = self._process_adj_matrix(results["graph"])

graph = self._process_cpdag(graph)

num_possible_dags = graph.shape[0]

new_adj_matrix.append(np.repeat(adj_matrix[graph_index==i][0][np.newaxis, ...], n_samples*num_possible_dags, axis=0))
graphs.append(np.repeat(graph, n_samples, axis=0))

@@ -173,4 +162,5 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
if self.ignore_self_connections:
graphs = zero_out_diag_np(graphs)

return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(new_adj_matrix)

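As context for the calls above, a standalone sketch of the tigramite API this trainer wraps, on synthetic data; shapes follow tigramite's documented conventions.

import numpy as np
from tigramite import data_processing as pp
from tigramite.pcmci import PCMCI
from tigramite.independence_tests.parcorr import ParCorr

data = np.random.randn(500, 4)   # [timesteps, num_nodes]
pcmci = PCMCI(dataframe=pp.DataFrame(data), cond_ind_test=ParCorr(), verbosity=0)
results = pcmci.run_pcmci(tau_max=2, pc_alpha=0.05)

# results["graph"] is a [num_nodes, num_nodes, tau_max+1] array of link strings
# such as "-->" or "o-o"; _process_adj_matrix binarizes it, and _process_cpdag
# expands the resulting CPDAG into its member DAGs.
print(results["graph"].shape)
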
15 changes: 5 additions & 10 deletions src/baselines/VARLiNGAMTrainer.py
@@ -1,11 +1,12 @@
from typing import Any
-from src.baselines.BaselineTrainer import BaselineTrainer
import numpy as np

import lingam
-from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np
import torch

+from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np
+from src.baselines.BaselineTrainer import BaselineTrainer

class VARLiNGAMTrainer(BaselineTrainer):

def __init__(self,
@@ -17,7 +18,6 @@ def __init__(self,
num_workers: int = 16,
aggregated_graph: bool = False
):

super().__init__(full_dataset=full_dataset,
adj_matrices=adj_matrices,
data_dim=data_dim,
@@ -27,28 +27,23 @@ def __init__(self,
aggregated_graph=aggregated_graph)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
-X, adj_matrix, graph_index = batch
+X, adj_matrix, _ = batch

-batch, timesteps, num_nodes, data_dim = X.shape
+batch, timesteps, num_nodes, _ = X.shape
X = X.view(batch, timesteps, -1)

assert num_nodes == self.num_nodes
assert batch == 1, "VARLiNGAM needs batch size 1"

model_pruned = lingam.VARLiNGAM(lags=self.lag, prune=True)
model_pruned.fit(X[0])

graph = np.transpose(np.abs(model_pruned.adjacency_matrices_) > 0, axes=[0, 2, 1])

if graph.shape[0] != (self.lag+1):
while graph.shape[0] != (self.lag+1):
graph = np.concatenate((graph, np.zeros((1, num_nodes, num_nodes) )), axis=0)

graphs = [graph]

if self.aggregated_graph:
graphs = to_time_aggregated_graph_np(graphs)

print(graphs)
print(adj_matrix)
return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(adj_matrix)
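
For reference, a standalone sketch of the lingam API used here, on synthetic data. The exact number of slices in adjacency_matrices_ under prune=True is an assumption, which is why predict_step zero-pads the result to lag+1.

import numpy as np
import lingam

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(500, 3))   # single series, [timesteps, num_nodes]
model = lingam.VARLiNGAM(lags=2, prune=True)
model.fit(X)

# adjacency_matrices_ stacks one [num_nodes, num_nodes] coefficient matrix per
# lag, instantaneous effects first; the transpose in predict_step flips it to
# the row -> column edge convention of the ground-truth matrices.
graph = np.transpose(np.abs(model.adjacency_matrices_) > 0, axes=[0, 2, 1])
print(graph.shape)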
20 changes: 7 additions & 13 deletions src/dataset/BaselineTSDataset.py
@@ -1,15 +1,12 @@
from torch.utils.data import Dataset
-from src.utils.utils import *
-from src.utils.data_utils.data_format_utils import *
import torch
+from src.utils.data_utils.data_format_utils import get_adj_matrix_id

class BaselineTSDataset(Dataset):
    # Dataset class that can be used with the baselines PCMCI(+), VARLiNGAM and DYNOTEARS

    def __init__(self,
                 X,
                 adj_matrix,
                 lag,
aggregated_graph=False,
return_graph_indices=False):
"""
@@ -18,11 +15,9 @@ def __init__(self,
"""
self.lag = lag
self.aggregated_graph = aggregated_graph

self.X = X
self.adj_matrix = adj_matrix
self.return_graph_indices = return_graph_indices

if self.return_graph_indices:
self.unique_matrices, self.matrix_indices = get_adj_matrix_id(self.adj_matrix)

@@ -32,6 +27,5 @@ def __len__(self):
def __getitem__(self, index):
if not self.return_graph_indices:
return self.X[index], self.adj_matrix[index]
-else:
-    return self.X[index], self.adj_matrix[index], self.matrix_indices[index]
+return self.X[index], self.adj_matrix[index], self.matrix_indices[index]

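A short usage sketch of this dataset; the array shapes are assumed from how the trainers consume batches, not stated in this file.

import numpy as np
from torch.utils.data import DataLoader
from src.dataset.BaselineTSDataset import BaselineTSDataset

lag = 2
X = np.random.randn(8, 50, 5, 1)   # [samples, timesteps, nodes, data_dim] (assumed layout)
A = np.zeros((8, lag + 1, 5, 5))   # [samples, lag+1, nodes, nodes] (assumed layout)

ds = BaselineTSDataset(X=X, adj_matrix=A, lag=lag,
                       aggregated_graph=False, return_graph_indices=True)
x, adj, graph_idx = ds[0]          # graph index groups samples that share a ground-truth graph
loader = DataLoader(ds, batch_size=1)   # mirrors get_full_dataloader() in BaselineTrainer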
