diff --git a/configs/synthetic/mcd.yaml b/configs/synthetic/mcd.yaml
index 7f09265..2c07490 100644
--- a/configs/synthetic/mcd.yaml
+++ b/configs/synthetic/mcd.yaml
@@ -1,6 +1,6 @@
 # hyperparameters
 watch_gradients: false
-num_epochs: 1
+num_epochs: 10000
 model: mcd
 monitor_checkpoint_based_on: likelihood
diff --git a/src/baselines/BaselineTrainer.py b/src/baselines/BaselineTrainer.py
index 68ae29b..238ab58 100644
--- a/src/baselines/BaselineTrainer.py
+++ b/src/baselines/BaselineTrainer.py
@@ -1,13 +1,9 @@
-
-from typing import Any
 import lightning.pytorch as pl
-import torch.nn as nn
-import torch
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader
 import torch.nn.functional as F
+import numpy as np
 from src.dataset.BaselineTSDataset import BaselineTSDataset
-import numpy as np
 from src.utils.metrics_utils import mape_loss
 
 
 class BaselineTrainer(pl.LightningModule):
@@ -21,7 +17,6 @@ def __init__(self,
                  lag: int,
                  num_workers: int = 16,
                  aggregated_graph: bool = False):
-
         super().__init__()
 
         self.num_workers = num_workers
@@ -35,29 +30,20 @@ def __init__(self,
         assert adj_matrices.shape[0] == self.total_samples
 
         self.full_dataset = BaselineTSDataset(
-            X = self.full_dataset_np,
-            adj_matrix = self.adj_matrices_np,
-            lag=lag,
+            X = self.full_dataset_np,
+            adj_matrix = self.adj_matrices_np,
+            lag=lag,
             aggregated_graph=self.aggregated_graph,
             return_graph_indices=True
         )
         self.batch_size = 1
 
-    def forward(self):
-        raise NotImplementedError
+    def compute_mse(self, x_current, x_pred):
+        return F.mse_loss(x_current, x_pred)
 
-    def compute_mse(self, X_current, X_pred):
-        return F.mse_loss(X_current, X_pred)
-
-    def compute_mape(self, X_current, X_pred):
-        return mape_loss(X_current, X_pred)
-
-    def training_step(self, batch, batch_idx):
-        raise NotImplementedError
+    def compute_mape(self, x_current, x_pred):
+        return mape_loss(x_current, x_pred)
 
     def get_full_dataloader(self) -> DataLoader:
         return DataLoader(self.full_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
-
-    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
-        raise NotImplementedError
\ No newline at end of file
diff --git a/src/baselines/DYNOTEARSTrainer.py b/src/baselines/DYNOTEARSTrainer.py
index 6e15051..c7a0eb8 100644
--- a/src/baselines/DYNOTEARSTrainer.py
+++ b/src/baselines/DYNOTEARSTrainer.py
@@ -1,14 +1,15 @@
 from typing import Any
-from src.baselines.BaselineTrainer import BaselineTrainer
 import numpy as np
 import torch
-from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_np, zero_out_diag_np, zero_out_diag_torch
 # import tigramite for pcmci
-from src.modules.dynotears.dynotears import from_pandas_dynamic
 import networkx as nx
 import pandas as pd
+from src.baselines.BaselineTrainer import BaselineTrainer
+from src.modules.dynotears.dynotears import from_pandas_dynamic
+from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_np, zero_out_diag_np
+
 
 class DYNOTEARSTrainer(BaselineTrainer):
 
     def __init__(self,
@@ -43,7 +44,6 @@ def __init__(self,
                  num_workers=num_workers,
                  aggregated_graph=aggregated_graph)
-
         self.max_iter = max_iter
         self.lambda_w = lambda_w
         self.lambda_a = lambda_a
@@ -58,7 +58,7 @@ def __init__(self,
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
         X, adj_matrix, graph_index = batch
-        batch_size, timesteps, num_nodes, data_dim = X.shape
+        batch_size, timesteps, num_nodes, _ = X.shape
         assert
num_nodes == self.num_nodes X = X.view(batch_size, timesteps, -1) @@ -68,10 +68,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A graphs = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes)) scores = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes)) - if self.group_by_graph: n_unique_matrices = np.max(graph_index)+1 - unique_matrices = np.unique(adj_matrix, axis=0) else: graph_index = np.zeros((batch_size)) n_unique_matrices = 1 @@ -81,7 +79,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A n_samples = np.sum(graph_index == i) for x in X[graph_index == i]: X_list.append(pd.DataFrame(x)) - learner = from_pandas_dynamic( X_list, p=self.lag, @@ -102,16 +99,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A # scores = np.hstack(temporal_adj_list) temporal_adj = [(score != 0).astype(int) for _ in range(n_samples)] score = [np.abs(score) for _ in range(n_samples)] - graphs[i == graph_index] = np.array(temporal_adj) scores[i == graph_index] = np.array(score) - - if self.aggregated_graph: graphs = to_time_aggregated_graph_np(graphs) scores = to_time_aggregated_scores_np(scores) if self.ignore_self_connections: graphs = zero_out_diag_np(graphs) scores = zero_out_diag_np(scores) - - return torch.Tensor(graphs), torch.abs(torch.Tensor(scores)), torch.Tensor(adj_matrix) \ No newline at end of file + return torch.Tensor(graphs), torch.abs(torch.Tensor(scores)), torch.Tensor(adj_matrix) + \ No newline at end of file diff --git a/src/baselines/PCMCITrainer.py b/src/baselines/PCMCITrainer.py index 2f88579..be3e067 100644 --- a/src/baselines/PCMCITrainer.py +++ b/src/baselines/PCMCITrainer.py @@ -1,16 +1,17 @@ from typing import Any -from src.baselines.BaselineTrainer import BaselineTrainer +from copy import deepcopy + import numpy as np -from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, zero_out_diag_np, zero_out_diag_torch # import tigramite for pcmci -import tigramite from tigramite import data_processing as pp from tigramite.pcmci import PCMCI from tigramite.independence_tests.parcorr import ParCorr from tigramite.independence_tests.cmiknn import CMIknn -from copy import deepcopy import torch -from src.utils.causality_utils import * + +from src.utils.causality_utils import convert_temporal_to_static_adjacency_matrix, cpdag2dags +from src.baselines.BaselineTrainer import BaselineTrainer +from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, zero_out_diag_np """ Large parts adapted from https://github.com/microsoft/causica @@ -33,7 +34,6 @@ def __init__(self, group_by_graph: bool = False, ignore_self_connections: bool = False ): - self.group_by_graph = group_by_graph self.ignore_self_connections = ignore_self_connections if self.group_by_graph: @@ -97,8 +97,7 @@ def _process_adj_matrix(self, adj_matrix: np.ndarray) -> np.ndarray: def _run_pcmci(self, pcmci, tau_max, pc_alpha): if self.pcmci_plus: return pcmci.run_pcmciplus(tau_max=tau_max, pc_alpha=pc_alpha) - else: - return pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha) + return pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha) def _process_cpdag(self, adj_matrix: np.ndarray): """ @@ -109,7 +108,6 @@ def _process_cpdag(self, adj_matrix: np.ndarray): Returns: adj_matrix: np.ndarray with shape [num_possible_dags, lag+1, num_nodes, num_nodes] """ - lag_plus, num_nodes = adj_matrix.shape[0], adj_matrix.shape[1] static_temporal_graph = 
convert_temporal_to_static_adjacency_matrix( adj_matrix, conversion_type="auto_regressive" @@ -130,13 +128,10 @@ def _process_cpdag(self, adj_matrix: np.ndarray): def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: X, adj_matrix, graph_index = batch - - batch_size, timesteps, num_nodes, data_dim = X.shape + batch_size, timesteps, num_nodes, _ = X.shape assert num_nodes == self.num_nodes - X = X.view(batch_size, timesteps, -1) X, adj_matrix, graph_index = X.numpy(), adj_matrix.numpy(), graph_index.numpy() - graphs = [] #np.zeros((batch_size, self.lag+1, num_nodes, num_nodes)) new_adj_matrix = [] if self.group_by_graph: @@ -144,8 +139,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A else: graph_index = np.zeros((batch_size)) n_unique_matrices = 1 - - unique_matrices = np.unique(adj_matrix, axis=0) for i in range(n_unique_matrices): print(f"{i}/{n_unique_matrices}") n_samples = np.sum(graph_index == i) @@ -156,13 +149,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A verbosity=0) results = self._run_pcmci(pcmci, self.lag, self.pc_alpha) - graph = self._process_adj_matrix(results["graph"]) - graph = self._process_cpdag(graph) - num_possible_dags = graph.shape[0] - new_adj_matrix.append(np.repeat(adj_matrix[graph_index==i][0][np.newaxis, ...], n_samples*num_possible_dags, axis=0)) graphs.append(np.repeat(graph, n_samples, axis=0)) @@ -173,4 +162,5 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A if self.ignore_self_connections: graphs = zero_out_diag_np(graphs) - return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(new_adj_matrix) \ No newline at end of file + return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(new_adj_matrix) + \ No newline at end of file diff --git a/src/baselines/VARLiNGAMTrainer.py b/src/baselines/VARLiNGAMTrainer.py index 2fe63d1..367e903 100644 --- a/src/baselines/VARLiNGAMTrainer.py +++ b/src/baselines/VARLiNGAMTrainer.py @@ -1,11 +1,12 @@ from typing import Any -from src.baselines.BaselineTrainer import BaselineTrainer import numpy as np import lingam -from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np import torch +from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np +from src.baselines.BaselineTrainer import BaselineTrainer + class VARLiNGAMTrainer(BaselineTrainer): def __init__(self, @@ -17,7 +18,6 @@ def __init__(self, num_workers: int = 16, aggregated_graph: bool = False ): - super().__init__(full_dataset=full_dataset, adj_matrices=adj_matrices, data_dim=data_dim, @@ -27,9 +27,9 @@ def __init__(self, aggregated_graph=aggregated_graph) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - X, adj_matrix, graph_index = batch + X, adj_matrix, _ = batch - batch, timesteps, num_nodes, data_dim = X.shape + batch, timesteps, num_nodes, _ = X.shape X = X.view(batch, timesteps, -1) assert num_nodes == self.num_nodes @@ -37,18 +37,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A model_pruned = lingam.VARLiNGAM(lags=self.lag, prune=True) model_pruned.fit(X[0]) - graph = np.transpose(np.abs(model_pruned.adjacency_matrices_) > 0, axes=[0, 2, 1]) - if graph.shape[0] != (self.lag+1): while graph.shape[0] != (self.lag+1): graph = np.concatenate((graph, np.zeros((1, num_nodes, num_nodes) )), axis=0) - graphs = [graph] - if self.aggregated_graph: graphs = to_time_aggregated_graph_np(graphs) - 
print(graphs) print(adj_matrix) return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(adj_matrix) \ No newline at end of file diff --git a/src/dataset/BaselineTSDataset.py b/src/dataset/BaselineTSDataset.py index 97b7606..031ff24 100644 --- a/src/dataset/BaselineTSDataset.py +++ b/src/dataset/BaselineTSDataset.py @@ -1,15 +1,12 @@ from torch.utils.data import Dataset -from src.utils.utils import * -from src.utils.data_utils.data_format_utils import * -import torch +from src.utils.data_utils.data_format_utils import get_adj_matrix_id class BaselineTSDataset(Dataset): # Dataset class that can be used with the baselines PCMCI(+), VARLiNGAM and DYNOTEARS - - def __init__(self, - X, - adj_matrix, - lag, + def __init__(self, + X, + adj_matrix, + lag, aggregated_graph=False, return_graph_indices=False): """ @@ -18,11 +15,9 @@ def __init__(self, """ self.lag = lag self.aggregated_graph = aggregated_graph - self.X = X self.adj_matrix = adj_matrix self.return_graph_indices = return_graph_indices - if self.return_graph_indices: self.unique_matrices, self.matrix_indices = get_adj_matrix_id(self.adj_matrix) @@ -32,6 +27,5 @@ def __len__(self): def __getitem__(self, index): if not self.return_graph_indices: return self.X[index], self.adj_matrix[index] - else: - return self.X[index], self.adj_matrix[index], self.matrix_indices[index] - \ No newline at end of file + return self.X[index], self.adj_matrix[index], self.matrix_indices[index] + \ No newline at end of file diff --git a/src/dataset/FragmentDataset.py b/src/dataset/FragmentDataset.py index d47b4cb..2ad1619 100644 --- a/src/dataset/FragmentDataset.py +++ b/src/dataset/FragmentDataset.py @@ -1,25 +1,24 @@ """ Terminology: -A fragment refers to the pair of tensors X_history and X_current, where X_history represents -the lag information (X(t-lag) to X(t-1)) and X_current represents the current information X(t). +A fragment refers to the pair of tensors X_history and x_current, where X_history represents +the lag information (X(t-lag) to X(t-1)) and x_current represents the current information X(t). Note that which sample a fragment comes from is irrelevant, since all we are concerned about is the causal graph which generated the fragment. 
""" from torch.utils.data import Dataset -from src.utils.utils import * -from src.utils.data_utils.data_format_utils import * import torch +from src.utils.data_utils.data_format_utils import convert_data_to_timelagged, convert_adj_to_timelagged -class FragmentDataset(Dataset): - def __init__(self, - X, - adj_matrix, - lag, - return_graph_indices=True, +class FragmentDataset(Dataset): + def __init__(self, + X, + adj_matrix, + lag, + return_graph_indices=True, aggregated_graph=False): """ X: np.array of shape (n_samples, timesteps, num_nodes, data_dim) @@ -29,7 +28,7 @@ def __init__(self, self.aggregated_graph = aggregated_graph self.return_graph_indices = return_graph_indices # preprocess data - self.X_history, self.X_current, self.X_indices = convert_data_to_timelagged( + self.X_history, self.x_current, self.X_indices = convert_data_to_timelagged( X, lag=lag) if self.return_graph_indices: self.adj_matrix, self.graph_indices = convert_adj_to_timelagged( @@ -46,9 +45,9 @@ def __init__(self, aggregated_graph=self.aggregated_graph, return_indices=False) - self.X_history, self.X_current, self.adj_matrix, self.X_indices = \ - torch.Tensor(self.X_history), torch.Tensor(self.X_current), torch.Tensor(self.adj_matrix), torch.Tensor(self.X_indices) - + self.X_history, self.x_current, self.adj_matrix, self.X_indices = \ + torch.Tensor(self.X_history), torch.Tensor(self.x_current), torch.Tensor( + self.adj_matrix), torch.Tensor(self.X_indices) if self.return_graph_indices: self.graph_indices = torch.Tensor(self.graph_indices) @@ -59,7 +58,6 @@ def __len__(self): def __getitem__(self, index): if self.return_graph_indices: - return self.X_history[index], self.X_current[index], self.adj_matrix[index], self.X_indices[index].long(), self.graph_indices[index].long() - else: - return self.X_history[index], self.X_current[index], self.adj_matrix[index], self.X_indices[index].long() - \ No newline at end of file + return self.X_history[index], self.x_current[index], self.adj_matrix[index], self.X_indices[index].long(), \ + self.graph_indices[index].long() + return self.X_history[index], self.x_current[index], self.adj_matrix[index], self.X_indices[index].long() diff --git a/src/model/BaseTrainer.py b/src/model/BaseTrainer.py index e05b673..5b2563e 100644 --- a/src/model/BaseTrainer.py +++ b/src/model/BaseTrainer.py @@ -1,16 +1,14 @@ import lightning.pytorch as pl -import torch.nn as nn import torch from torch.utils.data import DataLoader, TensorDataset import torch.nn.functional as F +import numpy as np -from torch.optim import Adam -from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau, ConstantLR from src.dataset.FragmentDataset import FragmentDataset -import numpy as np from src.utils.metrics_utils import mape_loss + class BaseTrainer(pl.LightningModule): def __init__(self, @@ -43,7 +41,6 @@ def __init__(self, self.adj_matrices_np = adj_matrices[I] self.use_all_for_val = use_all_for_val - self.val_frac = val_frac assert self.val_frac >= 0.0 and self.val_frac < 1.0, "Validation fraction should be between 0 and 1" @@ -55,69 +52,64 @@ def __init__(self, print("Using *all* examples for validation. 
Ignoring val_frac...") self.train_dataset_np = self.full_dataset_np self.val_dataset_np = self.full_dataset_np - self.train_adj_np = self.adj_matrices_np self.val_adj_np = self.adj_matrices_np else: - self.train_dataset_np = self.full_dataset_np[:int((1-self.val_frac)*num_samples)] - self.val_dataset_np = self.full_dataset_np[int((1-self.val_frac)*num_samples):] - - self.train_adj_np = self.adj_matrices_np[:int((1-self.val_frac)*num_samples)] - self.val_adj_np = self.adj_matrices_np[int((1-self.val_frac)*num_samples):] - + self.train_dataset_np = self.full_dataset_np[:int( + (1-self.val_frac)*num_samples)] + self.val_dataset_np = self.full_dataset_np[int( + (1-self.val_frac)*num_samples):] + self.train_adj_np = self.adj_matrices_np[:int( + (1-self.val_frac)*num_samples)] + self.val_adj_np = self.adj_matrices_np[int( + (1-self.val_frac)*num_samples):] self.train_frag_dataset = FragmentDataset( - self.train_dataset_np, - self.train_adj_np, - lag=lag, + self.train_dataset_np, + self.train_adj_np, + lag=lag, aggregated_graph=self.aggregated_graph, return_graph_indices=self.return_graph_indices) - self.val_frag_dataset = FragmentDataset( - self.val_dataset_np, - self.val_adj_np, - lag=lag, + self.val_dataset_np, + self.val_adj_np, + lag=lag, aggregated_graph=self.aggregated_graph, return_graph_indices=self.return_graph_indices) - # self.full_frag_dataset = FragmentDataset( - # self.full_dataset_np, - # self.adj_matrices_np, - # lag=lag, + # self.full_dataset_np, + # self.adj_matrices_np, + # lag=lag, # aggregated_graph=self.aggregated_graph, # return_graph_indices=self.return_graph_indices) - self.num_fragments = len(self.train_frag_dataset) - self.full_dataset = TensorDataset( - torch.Tensor(self.full_dataset_np), - torch.Tensor(self.adj_matrices_np), - torch.arange(self.full_dataset_np.shape[0])) - + torch.Tensor(self.full_dataset_np), + torch.Tensor(self.adj_matrices_np), + torch.arange(self.full_dataset_np.shape[0])) if self.batch_size is None: # do full-batch training self.batch_size = self.num_fragments - + def forward(self): raise NotImplementedError - def compute_loss(self, X_history, X_current, X_full, adj_matrix): + def compute_loss(self, X_history, x_current, X_full, adj_matrix): raise NotImplementedError - def compute_mse(self, X_current, X_pred): - return F.mse_loss(X_current, X_pred) + def compute_mse(self, x_current, x_pred): + return F.mse_loss(x_current, x_pred) - def compute_mape(self, X_current, X_pred): - return mape_loss(X_current, X_pred) + def compute_mape(self, x_current, x_pred): + return mape_loss(x_current, x_pred) def training_step(self, batch, batch_idx): - X_history, X_current, X_full, adj_matrix = batch - loss = self.compute_loss(X_history, X_current, X_full, adj_matrix) + X_history, x_current, X_full, adj_matrix = batch + loss = self.compute_loss(X_history, x_current, X_full, adj_matrix) return loss def validation_step(self, batch, batch_idx): - X_history, X_current, X_full, adj_matrix = batch - loss = self.compute_loss(X_history, X_current, X_full, adj_matrix) - + X_history, x_current, X_full, adj_matrix = batch + loss = self.compute_loss(X_history, x_current, X_full, adj_matrix) self.log("val_loss", loss) def train_dataloader(self) -> DataLoader: @@ -132,8 +124,8 @@ def get_full_dataloader(self) -> DataLoader: def track_gradients(self, m, log_name): total_norm = 0 for p in m.parameters(): - if p.grad != None: + if p.grad is not None: param_norm = p.grad.data.norm(2) total_norm += param_norm.item() ** 2 total_norm = total_norm ** (1. 
/ 2) - self.log(log_name, total_norm) \ No newline at end of file + self.log(log_name, total_norm) diff --git a/src/model/MCDTrainer.py b/src/model/MCDTrainer.py index 8dbeed0..c63c702 100644 --- a/src/model/MCDTrainer.py +++ b/src/model/MCDTrainer.py @@ -1,25 +1,17 @@ -import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch -import torch.nn.functional as F +from sklearn.metrics import accuracy_score +import numpy as np -from src.modules.MultiEmbedding import MultiEmbedding from src.model.RhinoTrainer import RhinoTrainer -import numpy as np from src.utils.metrics_utils import cluster_accuracy from src.utils.loss_utils import dag_penalty_notears, temporal_graph_sparsity from src.utils.data_utils.data_format_utils import to_time_aggregated_scores_torch, zero_out_diag_torch, get_adj_matrix_id from src.modules.adjacency_matrices.MultiTemporalAdjacencyMatrix import MultiTemporalAdjacencyMatrix from src.modules.MixtureSelectionLogits import MixtureSelectionLogits -from sklearn.metrics import accuracy_score, roc_auc_score -import math -from src.training.auglag import AugLagLRConfig, AugLagLR, AugLagLossCalculator - -class MCDTrainer(RhinoTrainer): - """ - """ +class MCDTrainer(RhinoTrainer): def __init__(self, full_dataset: np.array, adj_matrices: np.array, @@ -29,10 +21,9 @@ def __init__(self, num_graphs: int, causal_decoder, tcsf, - - use_correct_mixture_index: bool = False, # diagnostic option to check if + use_correct_mixture_index: bool = False, # diagnostic option to check if # the graph can be learnt when the correct mixture index is given - use_true_graph: bool = False, # diagnostic option to check if + use_true_graph: bool = False, # diagnostic option to check if # the correct mixture index is learnt when the correct graph is given likelihood_loss: str = 'flow', ignore_self_connections: bool = False, @@ -44,12 +35,12 @@ def __init__(self, threeway_graph_dist: bool = True, skip_auglag_epochs: int = 0, training_procedure: str = 'auglag', - training_config = None, - init_logits = [0, 0], - disable_inst = False, + training_config=None, + init_logits=[0, 0], + disable_inst=False, graph_selection_prior_lambda: float = 0.0, - use_indices = None, + use_indices=None, log_num_unique_graphs=False, use_all_for_val=False, shuffle=True @@ -82,18 +73,14 @@ def __init__(self, training_procedure=training_procedure, disable_inst=disable_inst ) - # initialize point-wise logits self.mixture_selection = MixtureSelectionLogits(num_graphs=self.num_graphs, num_samples=self.total_samples) - self.use_correct_mixture_index = use_correct_mixture_index self.use_true_graph = use_true_graph - # if use correct graph is set, have the set of correct graphs ready if self.use_true_graph: - self.true_graphs = torch.Tensor(np.unique(adj_matrices, axis=0)) - + self.true_graphs = torch.Tensor(np.unique(adj_matrices, axis=0)) self.use_indices = use_indices def initialize_graph(self): @@ -106,14 +93,14 @@ def initialize_graph(self): init_logits=self.init_logits, disable_inst=self.disable_inst ) - + def set_mixture_indices(self, indices): self.use_indices = indices def forward(self): raise NotImplementedError - - def compute_loss(self, X_history, X_current, idx): + + def compute_loss(self, X_history, x_current, idx): # first, get the mixture assignment probabilities for each point mixture_probs = self.mixture_selection.get_probs(idx) # mixture_probs shape: (num_graphs, batch) @@ -124,85 +111,95 @@ def compute_loss(self, X_history, X_current, idx): # next, sample G from variational distributions G 
= self.graphs.sample_A() # G shape: (num_graphs, lag+1, num_nodes, num_nodes) - - loss_terms, component_wise_likelihood = self.compute_loss_terms( X_history=X_history, - X_current=X_current, + x_current=x_current, G=G, sample_idx=idx, mixture_probs=mixture_probs) - + total_loss = loss_terms['likelihood'] + \ - loss_terms['graph_prior'] + \ - loss_terms['graph_entropy'] + \ - loss_terms['graph_selection_entropy'] + \ - loss_terms['graph_selection_prior'] - + loss_terms['graph_prior'] + \ + loss_terms['graph_entropy'] + \ + loss_terms['graph_selection_entropy'] + \ + loss_terms['graph_selection_prior'] + return total_loss, loss_terms, component_wise_likelihood - - def compute_loss_terms(self, - X_history: torch.Tensor, - X_current: torch.Tensor, + + def compute_loss_terms(self, + X_history: torch.Tensor, + x_current: torch.Tensor, G: torch.Tensor, sample_idx: torch.Tensor, mixture_probs: torch.Tensor, ): - batch, num_nodes, data_dim = X_current.shape + batch, _, _ = x_current.shape # *********** graph prior ***************** - + # sparsity term graph_sparsity_term = self.sparsity_factor * \ temporal_graph_sparsity(G) # dagness factors - dagness_penalty = dag_penalty_notears(G[:, 0]) + dagness_penalty = dag_penalty_notears(G[:, 0]) graph_prior_term = graph_sparsity_term graph_prior_term /= self.num_fragments # ************* graph entropy term ************ graph_entropy_term = -self.graphs.entropy()/self.num_fragments - + # ********* graph selection prior term ******** # weight term ce_loss = nn.CrossEntropyLoss() - W = -(self.graph_selection_prior_lambda * torch.arange(self.num_graphs, device=self.device)).float() - graph_selection_prior_term = ce_loss(W.unsqueeze(1).repeat(1, batch).transpose(0, 1), self.mixture_selection.get_probs(sample_idx).transpose(0, 1))/(self.num_fragments)*self.total_samples + W = -(self.graph_selection_prior_lambda * torch.arange(self.num_graphs, + device=self.device)).float() + graph_selection_prior_term = ce_loss(W.unsqueeze(1).repeat(1, batch).transpose( + 0, 1), self.mixture_selection.get_probs(sample_idx).transpose(0, 1))/(self.num_fragments)*self.total_samples # graph_selection_prior_term /= batch # ********** graph selection logits entropy *** - graph_selection_entropy_term = - self.mixture_selection.entropy(sample_idx)/(self.num_fragments)*self.total_samples + graph_selection_entropy_term = - \ + self.mixture_selection.entropy( + sample_idx)/(self.num_fragments)*self.total_samples # ************* likelihood loss **************** - X_input = torch.cat((X_history, X_current.unsqueeze(1)), dim=1) - X_pred = self.causal_decoder(X_input, G) - # X_pred shape: (batch, num_graphs, num_nodes, data_dim) - + X_input = torch.cat((X_history, x_current.unsqueeze(1)), dim=1) + x_pred = self.causal_decoder(X_input, G) + # x_pred shape: (batch, num_graphs, num_nodes, data_dim) + if self.likelihood_loss == 'mse': - X_current = X_current.unsqueeze(1).expand((-1, self.num_graphs, -1, -1)) - component_wise_likelihood = torch.sum(torch.square(X_pred - X_current), dim=(2, 3)) - likelihood_term = mixture_probs.transpose(0, -1) * component_wise_likelihood + x_current = x_current.unsqueeze(1).expand( + (-1, self.num_graphs, -1, -1)) + component_wise_likelihood = torch.sum( + torch.square(x_pred - x_current), dim=(2, 3)) + likelihood_term = mixture_probs.transpose( + 0, -1) * component_wise_likelihood likelihood_term = torch.sum(torch.mean(likelihood_term, dim=0)) elif self.likelihood_loss == 'flow': - X_current = X_current.unsqueeze(1).expand((-1, self.num_graphs, -1, -1)) - log_prob 
= self.tcsf.log_prob(X_input=(X_current - X_pred).reshape(batch*self.num_graphs, -1), - X_history=X_history, + x_current = x_current.unsqueeze(1).expand( + (-1, self.num_graphs, -1, -1)) + log_prob = self.tcsf.log_prob(X_input=(x_current - x_pred).reshape(batch*self.num_graphs, -1), + X_history=X_history, A=G).view(batch, self.num_graphs, -1).sum(-1) # weight the likelihood term by the mixture selection probabilities # mixture_probs.shape: (num_graphs, batch), log_prob.shape: (batch, self.num_graphs) - component_wise_likelihood = log_prob * mixture_probs.transpose(0, -1) - likelihood_term = -torch.sum(torch.mean(component_wise_likelihood, dim=0)) - + component_wise_likelihood = log_prob * \ + mixture_probs.transpose(0, -1) + likelihood_term = - \ + torch.sum(torch.mean(component_wise_likelihood, dim=0)) + mixture_index = torch.argmax(mixture_probs, dim=0) - mse_loss = self.compute_mse(X_current[:, mixture_index], X_pred[:, mixture_index]) - mape_loss = self.compute_mape(X_current[:, mixture_index], X_pred[:, mixture_index]) - + mse_loss = self.compute_mse( + x_current[:, mixture_index], x_pred[:, mixture_index]) + mape_loss = self.compute_mape( + x_current[:, mixture_index], x_pred[:, mixture_index]) + # ************************************************ loss_terms = { 'graph_sparsity': graph_sparsity_term, @@ -218,56 +215,59 @@ def compute_loss_terms(self, 'likelihood': likelihood_term } return loss_terms, component_wise_likelihood - + def training_step(self, batch, batch_idx): - X_history, X_current, adj_matrix, idx, mixture_idx = batch - + X_history, x_current, adj_matrix, idx, mixture_idx = batch + if self.use_correct_mixture_index: self.mixture_selection.manual_set_mixture_indices(idx, mixture_idx) - if self.use_indices != None: + if self.use_indices is not None: self.use_indices = self.use_indices.to(self.device).long() - self.mixture_selection.manual_set_mixture_indices(idx, self.use_indices[idx]) - - loss, loss_terms, _ = self.compute_loss(X_history, X_current, idx) + self.mixture_selection.manual_set_mixture_indices( + idx, self.use_indices[idx]) + + loss, loss_terms, _ = self.compute_loss(X_history, x_current, idx) self.track_gradients(self.mixture_selection, "mixture_grad") - + if self.log_num_unique_graphs: - graph_index = torch.argmax(self.mixture_selection.get_probs(idx), dim=0) - self.log('num_unique_graphs', float(torch.unique(graph_index).shape[0])) + graph_index = torch.argmax( + self.mixture_selection.get_probs(idx), dim=0) + self.log('num_unique_graphs', float( + torch.unique(graph_index).shape[0])) loss_terms['loss'] = loss return loss_terms # convert nan gradients to zero def on_after_backward(self) -> None: for p in self.parameters(): - if p.grad != None: + if p.grad is not None: p.grad = torch.nan_to_num(p.grad) return super().on_after_backward() - - def update_validation_indices(self, X_history, X_current, idx): + + def update_validation_indices(self, X_history, x_current, idx): with torch.inference_mode(False): with torch.enable_grad(): optimizer = self.optimizers() - + for param in self.parameters(): param.requires_grad_(False) self.mixture_selection.requires_grad_(True) loss, _, _ = self.compute_loss(X_history=X_history, - X_current=X_current, - idx=idx) + x_current=x_current, + idx=idx) optimizer.zero_grad() - loss.backward() + loss.backward() optimizer.step() - + for param in self.parameters(): param.requires_grad_(True) def validation_step(self, batch, batch_idx): - X_history, X_current, adj_matrix, idx, mixture_idx = batch + X_history, x_current, adj_matrix, 
idx, mixture_idx = batch # use the correct offset if not self.use_all_for_val: @@ -277,26 +277,29 @@ def validation_step(self, batch, batch_idx): self.mixture_selection.manual_set_mixture_indices(idx, mixture_idx) elif not self.use_all_for_val: # update the mixture indices for the validation dataset - self.update_validation_indices(X_history, X_current, idx) - + self.update_validation_indices(X_history, x_current, idx) + # first select graph for each sample - graph_index = torch.argmax(self.mixture_selection.get_probs(idx), dim=0) + graph_index = torch.argmax( + self.mixture_selection.get_probs(idx), dim=0) if self.use_true_graph: G = self.true_graphs.to(self.device)[graph_index] else: G = self.graphs.sample_A()[graph_index] - - loss_terms = self.validation_func(X_history, X_current, adj_matrix, G, idx) + + loss_terms = self.validation_func( + X_history, x_current, adj_matrix, G, idx) # if we are using the correct graph, evaluate the accuracy with which # the correct mixture index is being learnt if self.use_true_graph: - pred_mixture_idx = self.mixture_selection.get_mixture_indices(idx).detach().cpu().numpy() + pred_mixture_idx = self.mixture_selection.get_mixture_indices( + idx).detach().cpu().numpy() true_idx = mixture_idx.detach().cpu().numpy() mixture_idx_acc = accuracy_score(true_idx, pred_mixture_idx) loss_terms['mixture_idx_acc'] = mixture_idx_acc self.log('mixture_idx_acc', mixture_idx_acc) - + if not (self.use_true_graph or self.use_correct_mixture_index): # log the cluster accuracy true_idx = mixture_idx.detach().cpu().numpy() @@ -305,7 +308,8 @@ def validation_step(self, batch, batch_idx): self.log('cluster_acc', loss_terms['cluster_acc']) # first select graph for each sample - graph_index = torch.argmax(self.mixture_selection.get_probs(idx), dim=0) + graph_index = torch.argmax( + self.mixture_selection.get_probs(idx), dim=0) # get the corresponding graph probabilities probs = self.graphs.get_adj_matrix(do_round=False)[graph_index] if self.aggregated_graph: @@ -315,7 +319,6 @@ def validation_step(self, batch, batch_idx): return loss_terms - def configure_optimizers(self): """Set the learning rates for different sets of parameters.""" modules = { @@ -347,7 +350,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): batch_size = X_full.shape[0] # first select graph for each sample - graph_index = torch.argmax(self.mixture_selection.get_probs(idx), dim=0) + graph_index = torch.argmax( + self.mixture_selection.get_probs(idx), dim=0) # get the corresponding graph probabilities probs = self.graphs.get_adj_matrix(do_round=False)[graph_index] if self.aggregated_graph: @@ -360,6 +364,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): def get_cluster_indices(self): true_cluster_idx = get_adj_matrix_id(self.adj_matrices_np)[1] - pred_cluster_idx = self.mixture_selection.get_mixture_indices(torch.arange(self.total_samples)) + pred_cluster_idx = self.mixture_selection.get_mixture_indices( + torch.arange(self.total_samples)) return true_cluster_idx, pred_cluster_idx diff --git a/src/model/RhinoTrainer.py b/src/model/RhinoTrainer.py index 35c991e..4f88b85 100644 --- a/src/model/RhinoTrainer.py +++ b/src/model/RhinoTrainer.py @@ -1,24 +1,17 @@ -import lightning.pytorch as pl -import torch.nn as nn import torch -import torch.nn.functional as F -import math -from src.modules.CausalDecoder import CausalDecoder import numpy as np +from sklearn.metrics import f1_score + from src.utils.loss_utils import dag_penalty_notears, temporal_graph_sparsity from 
src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_torch, zero_out_diag_np, zero_out_diag_torch -from src.utils.metrics_utils import compute_shd, get_off_diagonal +from src.utils.metrics_utils import compute_shd from src.model.BaseTrainer import BaseTrainer from src.modules.adjacency_matrices.TemporalAdjacencyMatrix import TemporalAdjacencyMatrix from src.modules.adjacency_matrices.TwoWayTemporalAdjacencyMatrix import TwoWayTemporalAdjacencyMatrix -from sklearn.metrics import f1_score, roc_auc_score - from src.training.auglag import AugLagLRConfig, AuglagLRCallback, AugLagLR, AugLagLossCalculator -class RhinoTrainer(BaseTrainer): - """ - """ +class RhinoTrainer(BaseTrainer): def __init__(self, full_dataset: np.array, @@ -29,20 +22,20 @@ def __init__(self, causal_decoder, tcsf, disable_inst: bool = False, - likelihood_loss: str = 'flow', - sparsity_factor: float = 20, + likelihood_loss: str = 'flow', + sparsity_factor: float = 20, num_workers: int = 16, - batch_size: int = 256, + batch_size: int = 256, matrix_temperature: float = 0.25, aggregated_graph: bool = False, ignore_self_connections: bool = False, threeway_graph_dist: bool = True, skip_auglag_epochs: int = 0, training_procedure: str = 'auglag', - training_config = None, - init_logits = [0, 0], - use_all_for_val = False, - shuffle = True): + training_config=None, + init_logits=[0, 0], + use_all_for_val=False, + shuffle=True): self.aggregated_graph = aggregated_graph self.ignore_self_connections = ignore_self_connections @@ -65,8 +58,8 @@ def __init__(self, print("Number of fragments:", self.num_fragments) print("Number of samples:", self.total_samples) - - assert likelihood_loss == 'mse' or likelihood_loss == 'flow' + + assert likelihood_loss in ['mse', 'flow'] self.likelihood_loss = likelihood_loss self.sparsity_factor = sparsity_factor @@ -82,14 +75,15 @@ def __init__(self, if training_config is None: self.training_config = AugLagLRConfig() if self.skip_auglag_epochs > 0: - print(f"Not performing augmented lagrangian optimization for the first {self.skip_auglag_epochs} epochs...") + print( + f"Not performing augmented lagrangian optimization for the first {self.skip_auglag_epochs} epochs...") self.disabled_epochs = range(self.skip_auglag_epochs) else: - self.disabled_epochs = None + self.disabled_epochs = None self.lr_scheduler = AugLagLR(config=self.training_config) self.loss_calc = AugLagLossCalculator(init_alpha=self.training_config.init_alpha, - init_rho=self.training_config.init_rho) - + init_rho=self.training_config.init_rho) + def initialize_graph(self): if self.threeway_graph_dist: self.adj_matrix = TemporalAdjacencyMatrix( @@ -104,12 +98,13 @@ def initialize_graph(self): tau_gumbel=self.matrix_temperature, init_logits=self.init_logits, disable_inst=self.disable_inst) + def forward(self): raise NotImplementedError - def compute_loss_terms(self, X_history: torch.Tensor, X_current: torch.Tensor, G: torch.Tensor): + def compute_loss_terms(self, X_history: torch.Tensor, x_current: torch.Tensor, G: torch.Tensor): - #******************* graph prior ********************* + # ******************* graph prior ********************* graph_sparsity_term = self.sparsity_factor * \ temporal_graph_sparsity(G) @@ -129,22 +124,22 @@ def compute_loss_terms(self, X_history: torch.Tensor, X_current: torch.Tensor, G batch_size = X_history.shape[0] expanded_G = G.unsqueeze(0).repeat(batch_size, 1, 1, 1) - X_input = torch.cat((X_history, X_current.unsqueeze(1)), dim=1) - X_pred = 
self.causal_decoder(X_input, expanded_G) + X_input = torch.cat((X_history, x_current.unsqueeze(1)), dim=1) + x_pred = self.causal_decoder(X_input, expanded_G) - mse_loss = self.compute_mse(X_current, X_pred) - mape_loss = self.compute_mape(X_current, X_pred) + mse_loss = self.compute_mse(x_current, x_pred) + mape_loss = self.compute_mape(x_current, x_pred) if self.likelihood_loss == 'mse': likelihood_term = mse_loss elif self.likelihood_loss == 'flow': - batch, num_nodes, data_dim = X_current.shape - log_prob = self.tcsf.log_prob(X_input=(X_current - X_pred).view(batch, num_nodes*data_dim), - X_history=X_history, + batch, num_nodes, data_dim = x_current.shape + log_prob = self.tcsf.log_prob(X_input=(x_current - x_pred).view(batch, num_nodes*data_dim), + X_history=X_history, A=expanded_G).sum(-1) likelihood_term = -torch.mean(log_prob) - + loss_terms = { 'graph_sparsity': graph_sparsity_term, 'dagness_penalty': dagness_penalty, @@ -157,38 +152,38 @@ def compute_loss_terms(self, X_history: torch.Tensor, X_current: torch.Tensor, G 'likelihood': likelihood_term } return loss_terms - - def compute_loss(self, X_history, X_current, idx): + + def compute_loss(self, X_history, x_current, idx): # sample G G = self.adj_matrix.sample_A() loss_terms = self.compute_loss_terms( X_history=X_history, - X_current=X_current, + x_current=x_current, G=G) total_loss = loss_terms['likelihood'] +\ - loss_terms['graph_prior'] +\ - loss_terms['graph_entropy'] - + loss_terms['graph_prior'] +\ + loss_terms['graph_entropy'] + return total_loss, loss_terms, None def training_step(self, batch, batch_idx): - - X_history, X_current, adj_matrix, idx, _ = batch - loss, loss_terms, _ = self.compute_loss(X_history, X_current, idx) - - loss = self.loss_calc(loss, loss_terms['dagness_penalty']/self.num_fragments) + + X_history, x_current, _, idx, _ = batch + loss, loss_terms, _ = self.compute_loss(X_history, x_current, idx) + + loss = self.loss_calc( + loss, loss_terms['dagness_penalty']/self.num_fragments) self.log_dict(loss_terms, on_epoch=True) - loss_terms['loss'] = loss + loss_terms['loss'] = loss return loss_terms - - def validation_func(self, X_history, X_current, adj_matrix, G, idx): + def validation_func(self, X_history, x_current, adj_matrix, G, idx): batch_size = X_history.shape[0] - loss, loss_terms, _ = self.compute_loss(X_history, X_current, idx) + loss, loss_terms, _ = self.compute_loss(X_history, x_current, idx) G = G.detach().cpu().numpy() adj_matrix = adj_matrix.detach().cpu().numpy() @@ -203,26 +198,26 @@ def validation_func(self, X_history, X_current, adj_matrix, G, idx): f1 = f1_score(adj_matrix.flatten(), G.flatten()) else: mask = adj_matrix != G - + shd_loss = np.sum(mask)/batch_size shd_inst = np.sum(mask[:, 0])/batch_size shd_lag = np.sum(mask[:, 1:])/batch_size - + # shd_loss, shd_inst, shd_lag = compute_shd(adj_matrix, G) tp = np.logical_and(adj_matrix == 1, adj_matrix == G) fp = np.logical_and(adj_matrix != 1, G == 1) fn = np.logical_and(adj_matrix != 0, G == 0) - - f1_inst = 2*np.sum(tp[:, 0]) / (2*np.sum(tp[:, 0]) + np.sum(fp[:, 0]) + np.sum(fn[:, 0])) - f1_lag = 2*np.sum(tp[:, 1:]) / (2*np.sum(tp[:, 1:]) + np.sum(fp[:, 1:]) + np.sum(fn[:, 1:])) - + + f1_inst = 2*np.sum(tp[:, 0]) / (2*np.sum(tp[:, 0]) + + np.sum(fp[:, 0]) + np.sum(fn[:, 0])) + f1_lag = 2*np.sum(tp[:, 1:]) / (2*np.sum(tp[:, 1:]) + + np.sum(fp[:, 1:]) + np.sum(fn[:, 1:])) + # f1_inst = f1_score(get_off_diagonal(adj_matrix[:, 0]).flatten(), get_off_diagonal(G[:, 0]).flatten()) # f1_lag = f1_score(adj_matrix[:, 1:].flatten(), G[:, 
1:].flatten()) shd_loss = torch.Tensor([shd_loss]) shd_inst = torch.Tensor([shd_inst]) shd_lag = torch.Tensor([shd_lag]) - - self.true_graph = adj_matrix[0] if not self.aggregated_graph: loss_terms['shd_inst'] = shd_inst @@ -236,29 +231,28 @@ def validation_func(self, X_history, X_current, adj_matrix, G, idx): loss_terms['val_loss'] = loss - for key in loss_terms: - self.log(key, loss_terms[key]) + for key, item in loss_terms.items(): + self.log(key, item) return loss_terms def validation_step(self, batch, batch_idx): - X_history, X_current, adj_matrix, idx, _ = batch + X_history, x_current, adj_matrix, idx, _ = batch batch_size = X_history.shape[0] G = self.adj_matrix.sample_A() expanded_G = G.unsqueeze(0).repeat(batch_size, 1, 1, 1) - loss_terms = self.validation_func(X_history, X_current, adj_matrix, expanded_G, idx) - self.true_graph = adj_matrix[0] - + loss_terms = self.validation_func( + X_history, x_current, adj_matrix, expanded_G, idx) + probs = self.adj_matrix.get_adj_matrix(do_round=False) probs = probs.unsqueeze(0).repeat(batch_size, 1, 1, 1) if self.aggregated_graph: probs = to_time_aggregated_scores_torch(probs) if self.ignore_self_connections: probs = zero_out_diag_torch(probs) - - + return loss_terms def configure_optimizers(self): @@ -284,18 +278,20 @@ def configure_optimizers(self): assert module in check_modules, f"Module {module} not in module list" return torch.optim.Adam(parameter_list) - + def configure_callbacks(self): """Create a callback for the auglag callback.""" if self.training_procedure == 'auglag': return [AuglagLRCallback(self.lr_scheduler, log_auglag=True, disabled_epochs=self.disabled_epochs)] + return None # should not happen def predict_step(self, batch, batch_idx, dataloader_idx=0): X_full, adj_matrix, _ = batch batch_size = X_full.shape[0] - probs = self.adj_matrix.get_adj_matrix(do_round=False).unsqueeze(0).repeat(batch_size, 1, 1, 1) + probs = self.adj_matrix.get_adj_matrix( + do_round=False).unsqueeze(0).repeat(batch_size, 1, 1, 1) if self.aggregated_graph: probs = to_time_aggregated_scores_torch(probs) @@ -304,4 +300,3 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): # G = torch.bernoulli(probs) G = (probs >= 0.5).long() return G, probs, adj_matrix - diff --git a/src/model/generate_model.py b/src/model/generate_model.py index b257402..cc7430a 100644 --- a/src/model/generate_model.py +++ b/src/model/generate_model.py @@ -2,58 +2,51 @@ from hydra.utils import instantiate from src.modules.TemporalConditionalSplineFlow import TemporalConditionalSplineFlow + def generate_model(cfg: DictConfig): - lag = cfg.lag num_nodes = cfg.num_nodes data_dim = cfg.data_dim num_workers = cfg.num_workers aggregated_graph = cfg.aggregated_graph - if cfg.model == 'pcmci' or cfg.model == 'varlingam' \ - or cfg.model == 'dynotears': - - trainer = instantiate(cfg.trainer, - num_workers=num_workers, - lag=lag, - num_nodes=num_nodes, - data_dim=data_dim, - aggregated_graph=aggregated_graph) - + if cfg.model in ['pcmci', 'varlingam', 'dynotears']: + trainer = instantiate(cfg.trainer, + num_workers=num_workers, + lag=lag, + num_nodes=num_nodes, + data_dim=data_dim, + aggregated_graph=aggregated_graph) else: - multi_graph = cfg.model == 'mcd' if multi_graph: num_graphs = cfg.trainer.num_graphs - if 'decoder' in cfg: # generate the decoder - if not multi_graph: - causal_decoder = instantiate(cfg.decoder, - lag=lag, - num_nodes=num_nodes, - data_dim=data_dim) + if not multi_graph: + causal_decoder = instantiate(cfg.decoder, + lag=lag, + num_nodes=num_nodes, + 
data_dim=data_dim) else: - causal_decoder = instantiate(cfg.decoder, - lag=lag, - num_nodes=num_nodes, - data_dim=data_dim, - num_graphs=num_graphs) - + causal_decoder = instantiate(cfg.decoder, + lag=lag, + num_nodes=num_nodes, + data_dim=data_dim, + num_graphs=num_graphs) if 'likelihood_loss' in cfg.trainer and cfg.trainer.likelihood_loss == 'flow': # create hypernet if not multi_graph: - hypernet = instantiate(cfg.hypernet, - lag=lag, - num_nodes=num_nodes, - data_dim=data_dim) + hypernet = instantiate(cfg.hypernet, + lag=lag, + num_nodes=num_nodes, + data_dim=data_dim) else: - hypernet = instantiate(cfg.hypernet, - lag=lag, - num_nodes=num_nodes, - data_dim=data_dim, - num_graphs=num_graphs) - + hypernet = instantiate(cfg.hypernet, + lag=lag, + num_nodes=num_nodes, + data_dim=data_dim, + num_graphs=num_graphs) tcsf = TemporalConditionalSplineFlow(hypernet=hypernet) else: hypernet = None @@ -62,27 +55,27 @@ def generate_model(cfg: DictConfig): # create auglag config if cfg.trainer.training_procedure == 'auglag': training_config = instantiate(cfg.auglag_config) - + if cfg.model == 'rhino': - trainer = instantiate(cfg.trainer, - num_workers=num_workers, - lag=lag, - num_nodes=num_nodes, - data_dim=data_dim, - causal_decoder=causal_decoder, - tcsf=tcsf, - training_config=training_config, - aggregated_graph=aggregated_graph) + trainer = instantiate(cfg.trainer, + num_workers=num_workers, + lag=lag, + num_nodes=num_nodes, + data_dim=data_dim, + causal_decoder=causal_decoder, + tcsf=tcsf, + training_config=training_config, + aggregated_graph=aggregated_graph) elif cfg.model == 'mcd': - trainer = instantiate(cfg.trainer, - num_workers=num_workers, - lag=lag, - num_nodes=num_nodes, - data_dim=data_dim, - num_graphs=num_graphs, - causal_decoder=causal_decoder, - tcsf=tcsf, - training_config=training_config, - aggregated_graph=aggregated_graph) - + trainer = instantiate(cfg.trainer, + num_workers=num_workers, + lag=lag, + num_nodes=num_nodes, + data_dim=data_dim, + num_graphs=num_graphs, + causal_decoder=causal_decoder, + tcsf=tcsf, + training_config=training_config, + aggregated_graph=aggregated_graph) + return trainer diff --git a/src/modules/CausalDecoder.py b/src/modules/CausalDecoder.py index 681ba6d..09f56c5 100644 --- a/src/modules/CausalDecoder.py +++ b/src/modules/CausalDecoder.py @@ -1,21 +1,21 @@ import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch -import math from src.utils.torch_utils import generate_fully_connected + class CausalDecoder(pl.LightningModule): def __init__(self, - data_dim: int, - lag: int, + data_dim: int, + lag: int, num_nodes: int, - embedding_dim: int = None, + embedding_dim: int = None, skip_connection: bool = False, linear: bool = False ): super().__init__() - + if embedding_dim is None: embedding_dim = num_nodes * data_dim @@ -30,15 +30,13 @@ def __init__(self, torch.randn(self.lag + 1, self.num_nodes, self.embedding_dim, device=self.device) * 0.01 ), requires_grad=True) # shape (lag+1, num_nodes, embedding_dim) - - + input_dim = 2*self.embedding_dim self.nn_size = max(4 * num_nodes, self.embedding_dim, 64) - self.f = generate_fully_connected( input_dim=input_dim, - output_dim=num_nodes*data_dim, #potentially num_nodes + output_dim=num_nodes*data_dim, # potentially num_nodes hidden_dims=[self.nn_size, self.nn_size], non_linearity=nn.LeakyReLU, activation=nn.Identity, @@ -80,7 +78,7 @@ def forward(self, X_input: torch.Tensor, A: torch.Tensor, embeddings: torch.Tens assert (A.shape[0] == batch and A.shape[1] == lag+1 and 
A.shape[2] == num_nodes and A.shape[2] == num_nodes) - if embeddings == None: + if embeddings is None: E = self.embeddings.expand( X_input.shape[0], -1, -1, -1 ) @@ -94,14 +92,14 @@ def forward(self, X_input: torch.Tensor, A: torch.Tensor, embeddings: torch.Tens A_temp = A.flip([1]) # get the parents of X - X_sum = torch.einsum("blij,blio->bjo", A_temp, X_enc) # / num_nodes - + X_sum = torch.einsum("blij,blio->bjo", A_temp, + X_enc) # / num_nodes + X_sum = torch.cat([X_sum, E[:, 0, :, :]], dim=-1) # (batch, num_nodes, embedding_dim) # pass through f network to get the predictions - - self.group_mask = torch.eye(num_nodes*data_dim).to(self.device) - return torch.sum(self.f(X_sum)*self.group_mask, dim=-1).unsqueeze(-1) # (batch, num_nodes, data_dim) - else: - return torch.einsum("lij,blio->bjo", (self.w * A[0]).flip([0]), X_input) + group_mask = torch.eye(num_nodes*data_dim).to(self.device) + # (batch, num_nodes, data_dim) + return torch.sum(self.f(X_sum)*group_mask, dim=-1).unsqueeze(-1) + return torch.einsum("lij,blio->bjo", (self.w * A[0]).flip([0]), X_input) diff --git a/src/modules/LinearCausalGraph.py b/src/modules/LinearCausalGraph.py index 8b3b3c7..a7b34ba 100644 --- a/src/modules/LinearCausalGraph.py +++ b/src/modules/LinearCausalGraph.py @@ -1,10 +1,9 @@ import lightning.pytorch as pl import torch -import torch.distributions as td -import torch.nn.functional as F from torch import nn + class LinearCausalGraph(pl.LightningModule): def __init__( @@ -25,12 +24,11 @@ def __init__( ) self.I = torch.arange(input_dim) self.mask = torch.ones((self.lag+1, input_dim, input_dim)) - self.mask[0, self.I, self.I] = 0 + self.mask[0, self.I, self.I] = 0 self.input_dim = input_dim - + def get_w(self) -> torch.Tensor: """ Returns the matrix. Ensures that the instantaneous matrix has zero in the diagonals """ return self.w * self.mask.to(self.device) - diff --git a/src/modules/MixtureSelectionLogits.py b/src/modules/MixtureSelectionLogits.py index c385e1a..ab9c9b4 100644 --- a/src/modules/MixtureSelectionLogits.py +++ b/src/modules/MixtureSelectionLogits.py @@ -1,11 +1,12 @@ import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch import torch.nn.functional as F import torch.distributions as td + class MixtureSelectionLogits(pl.LightningModule): - + def __init__( self, num_samples: int, @@ -18,21 +19,21 @@ def __init__( self.num_samples = num_samples self.graph_select_logits = nn.Parameter(( torch.ones(self.num_graphs, self.num_samples, - device=self.device) * 0.01 + device=self.device) * 0.01 ), requires_grad=True) self.tau = tau - + def manual_set_mixture_indices(self, idx, mixture_idx): """ Use this function to manually set the mixture index. Mainly used for diagnostic/ablative purposes """ - + with torch.no_grad(): self.graph_select_logits[:, idx] = -10 self.graph_select_logits[mixture_idx, idx] = 10 self.graph_select_logits.requires_grad_(False) - + def set_logits(self, idx, logits): """ Use this function to manually set the logits. 
@@ -43,10 +44,10 @@ def set_logits(self, idx, logits): def reset_parameters(self): with torch.no_grad(): - self.graph_select_logits[:] = torch.ones(self.num_graphs, - self.num_samples, - device=self.device) * 0.01 - + self.graph_select_logits[:] = torch.ones(self.num_graphs, + self.num_samples, + device=self.device) * 0.01 + def turn_off_grad(self): self.graph_select_logits.requires_grad_(False) @@ -55,13 +56,13 @@ def turn_on_grad(self): def get_probs(self, idx): return F.softmax(self.graph_select_logits[:, idx]/self.tau, dim=0) - + def get_mixture_indices(self, idx): return torch.argmax(self.graph_select_logits[:, idx], dim=0) - + def entropy(self, idx): logits = self.graph_select_logits[:, idx]/self.tau dist = td.Categorical(logits=logits.transpose(0, -1)) entropy = dist.entropy().sum() - return entropy/idx.shape[0] \ No newline at end of file + return entropy/idx.shape[0] diff --git a/src/modules/MultiCausalDecoder.py b/src/modules/MultiCausalDecoder.py index fac7d89..bf6f473 100644 --- a/src/modules/MultiCausalDecoder.py +++ b/src/modules/MultiCausalDecoder.py @@ -1,24 +1,25 @@ import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch from src.modules.MultiEmbedding import MultiEmbedding from src.utils.torch_utils import generate_fully_connected + class MultiCausalDecoder(pl.LightningModule): def __init__(self, - data_dim: int, - lag: int, - num_nodes: int, + data_dim: int, + lag: int, + num_nodes: int, num_graphs: int, embedding_dim: int = None, skip_connection: bool = False, linear: bool = False, dropout_p: float = 0.0 ): - + super().__init__() - + if embedding_dim is not None: self.embedding_dim = embedding_dim else: @@ -38,7 +39,7 @@ def __init__(self, self.f = generate_fully_connected( input_dim=input_dim, - output_dim=num_nodes*data_dim, #potentially num_nodes + output_dim=num_nodes*data_dim, # potentially num_nodes hidden_dims=[self.nn_size, self.nn_size], non_linearity=nn.LeakyReLU, activation=nn.Identity, @@ -62,7 +63,7 @@ def __init__(self, lag=self.lag, num_graphs=self.num_graphs, embedding_dim=self.embedding_dim) - + else: self.w = nn.Parameter( torch.randn(self.num_graphs, self.lag+1, self.num_nodes, self.num_nodes, device=self.device)*0.5, requires_grad=True @@ -86,14 +87,15 @@ def forward(self, X_input: torch.Tensor, A: torch.Tensor): # reshape X to the correct shape A = A.unsqueeze(0).expand((batch, -1, -1, -1, -1)) E = E.unsqueeze(0).expand((batch, -1, -1, -1, -1)) - X_input = X_input.unsqueeze(1).expand((-1, self.num_graphs, -1, -1, -1)) + X_input = X_input.unsqueeze(1).expand( + (-1, self.num_graphs, -1, -1, -1)) # ensure we have the correct shape - assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and - A.shape[2] == lag + 1 and A.shape[3] == num_nodes and + assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and + A.shape[2] == lag + 1 and A.shape[3] == num_nodes and A.shape[4] == num_nodes) - assert (E.shape[0] == batch and E.shape[1] == self.num_graphs - and E.shape[2] == lag+1 and E.shape[3] == num_nodes + assert (E.shape[0] == batch and E.shape[1] == self.num_graphs + and E.shape[2] == lag+1 and E.shape[3] == num_nodes and E.shape[4] == self.embedding_dim) X_in = torch.cat((X_input, E), dim=-1) @@ -102,15 +104,14 @@ def forward(self, X_input: torch.Tensor, A: torch.Tensor): A_temp = A.flip([2]) # get the parents of X X_sum = torch.einsum("bnlij,bnlio->bnjo", A_temp, X_enc) - + X_sum = torch.cat([X_sum, E[:, :, 0, :, :]], dim=-1) # (batch, num_graphs, num_nodes, embedding_dim) # pass through f network to get 
the predictions - self.group_mask = torch.eye(num_nodes*data_dim).to(self.device) + group_mask = torch.eye(num_nodes*data_dim).to(self.device) # (batch, num_graphs, num_nodes, data_dim) - return torch.sum(self.f(X_sum)*self.group_mask, dim=-1).unsqueeze(-1) + return torch.sum(self.f(X_sum)*group_mask, dim=-1).unsqueeze(-1) - else: - return torch.einsum("klij,blio->bkjo", (self.w * A).flip([1]), X_input) - # return self.f(X_sum) \ No newline at end of file + return torch.einsum("klij,blio->bkjo", (self.w * A).flip([1]), X_input) + # return self.f(X_sum) diff --git a/src/modules/MultiEmbedding.py b/src/modules/MultiEmbedding.py index 8b5db85..8ddb054 100644 --- a/src/modules/MultiEmbedding.py +++ b/src/modules/MultiEmbedding.py @@ -1,9 +1,10 @@ import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch + class MultiEmbedding(pl.LightningModule): - + def __init__( self, num_nodes: int, @@ -21,15 +22,15 @@ def __init__( self.embedding_dim = embedding_dim self.lag_embeddings = nn.Parameter(( - torch.randn(self.num_graphs, self.lag, self.num_nodes, - self.embedding_dim, device=self.device) * 0.01 - ), requires_grad=True) + torch.randn(self.num_graphs, self.lag, self.num_nodes, + self.embedding_dim, device=self.device) * 0.01 + ), requires_grad=True) self.inst_embeddings = nn.Parameter(( - torch.randn(self.num_graphs, 1, self.num_nodes, - self.embedding_dim, device=self.device) * 0.01 - ), requires_grad=True) - + torch.randn(self.num_graphs, 1, self.num_nodes, + self.embedding_dim, device=self.device) * 0.01 + ), requires_grad=True) + def turn_off_inst_grad(self): self.inst_embeddings.requires_grad_(False) @@ -38,5 +39,3 @@ def turn_on_inst_grad(self): def get_embeddings(self): return torch.cat((self.inst_embeddings, self.lag_embeddings), dim=1) - - \ No newline at end of file diff --git a/src/modules/MultiLinearCausalGraph.py b/src/modules/MultiLinearCausalGraph.py index e27d894..03a5242 100644 --- a/src/modules/MultiLinearCausalGraph.py +++ b/src/modules/MultiLinearCausalGraph.py @@ -1,10 +1,9 @@ import lightning.pytorch as pl import torch -import torch.distributions as td -import torch.nn.functional as F from torch import nn + class MultiLinearCausalGraph(pl.LightningModule): def __init__( @@ -26,13 +25,13 @@ def __init__( torch.randn(self.num_graphs, self.lag+1, input_dim, input_dim, device=self.device)*0.5, requires_grad=True ) self.I = torch.arange(input_dim) - self.mask = torch.ones((self.num_graphs, self.lag+1, input_dim, input_dim)) - self.mask[:, 0, self.I, self.I] = 0 + self.mask = torch.ones( + (self.num_graphs, self.lag+1, input_dim, input_dim)) + self.mask[:, 0, self.I, self.I] = 0 self.input_dim = input_dim - + def get_w(self) -> torch.Tensor: """ Returns the matrix. 
Ensures that the instantaneous matrix has zero in the diagonals """ return self.w * self.mask.to(self.device) - diff --git a/src/modules/MultiTemporalHyperNet.py b/src/modules/MultiTemporalHyperNet.py index f02a302..44fc58f 100644 --- a/src/modules/MultiTemporalHyperNet.py +++ b/src/modules/MultiTemporalHyperNet.py @@ -1,10 +1,11 @@ +from typing import Dict, Tuple import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch -from typing import Dict, List, Optional, Tuple, Type from src.modules.MultiEmbedding import MultiEmbedding from src.utils.torch_utils import generate_fully_connected + class MultiTemporalHyperNet(pl.LightningModule): def __init__(self, @@ -18,7 +19,7 @@ def __init__(self, num_bins: int = 8, dropout_p: float = 0.0 ): - + super().__init__() if embedding_dim is not None: @@ -47,7 +48,7 @@ def __init__(self, (self.num_bins - 1), self.num_bins, ] # this is for linear order conditional spline flow - + self.total_param = sum(self.param_dim) input_dim = 2*self.embedding_dim @@ -55,7 +56,7 @@ def __init__(self, self.f = generate_fully_connected( input_dim=input_dim, - output_dim=self.total_param, #potentially num_nodes + output_dim=self.total_param, # potentially num_nodes hidden_dims=[self.nn_size, self.nn_size], non_linearity=nn.LeakyReLU, activation=nn.Identity, @@ -79,7 +80,7 @@ def __init__(self, lag=self.lag, num_graphs=self.num_graphs, embedding_dim=self.embedding_dim) - + def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: """ Args: @@ -96,11 +97,11 @@ def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: A = X["A"] X_in = X["X"] - + E = self.th_embeddings.get_embeddings() - #X["embeddings"] + # X["embeddings"] - batch, lag, num_nodes, data_dim = X_in.shape + batch, lag, num_nodes, _ = X_in.shape # reshape X to the correct shape A = A.unsqueeze(0).expand((batch, -1, -1, -1, -1)) @@ -108,11 +109,11 @@ def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: X_in = X_in.unsqueeze(1).expand((-1, self.num_graphs, -1, -1, -1)) # ensure we have the correct shape - assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and - A.shape[2] == lag + 1 and A.shape[3] == num_nodes and + assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and + A.shape[2] == lag + 1 and A.shape[3] == num_nodes and A.shape[4] == num_nodes) - assert (E.shape[0] == batch and E.shape[1] == self.num_graphs - and E.shape[2] == lag+1 and E.shape[3] == num_nodes + assert (E.shape[0] == batch and E.shape[1] == self.num_graphs + and E.shape[2] == lag+1 and E.shape[3] == num_nodes and E.shape[4] == self.embedding_dim) # shape [batch_size, num_graphs, lag, num_nodes, embedding_size] @@ -129,7 +130,7 @@ def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: # (batch, num_graphs, num_nodes, embedding_dim) A_temp = A[:, :, 1:].flip([2]) - X_sum = torch.einsum("bnlij,bnlio->bnjo", A_temp, X_enc) #/ num_nodes + X_sum = torch.einsum("bnlij,bnlio->bnjo", A_temp, X_enc) # / num_nodes X_sum = torch.cat((X_sum, E[:, :, 0, :, :]), dim=-1) diff --git a/src/modules/TemporalConditionalSplineFlow.py b/src/modules/TemporalConditionalSplineFlow.py index 1aa6248..b84dd75 100644 --- a/src/modules/TemporalConditionalSplineFlow.py +++ b/src/modules/TemporalConditionalSplineFlow.py @@ -1,17 +1,11 @@ import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch import pyro.distributions as distrib -import torch.distributions as td -from pyro.distributions.conditional import 
ConditionalTransform -from pyro.distributions.transforms import ComposeTransform -from pyro.distributions.transforms.spline import ConditionalSpline, Spline - -from src.modules.TemporalHyperNet import TemporalHyperNet +from pyro.distributions.transforms.spline import ConditionalSpline from pyro.distributions import constraints from pyro.distributions.torch_transform import TransformModule -from torch import nn class AffineDiagonalPyro(TransformModule): @@ -65,11 +59,11 @@ def __init__(self, self.num_bins = self.hypernet.num_bins self.order = self.hypernet.order - def log_prob(self, - X_input: torch.Tensor, - X_history: torch.Tensor, + def log_prob(self, + X_input: torch.Tensor, + X_history: torch.Tensor, A: torch.Tensor, - embeddings: torch.Tensor=None): + embeddings: torch.Tensor = None): """ Args: X_input: input data of shape (batch, num_nodes, data_dim) @@ -80,10 +74,10 @@ def log_prob(self, assert len(X_history.shape) == 4 - batch, lag, num_nodes, data_dim = X_history.shape + _, _, num_nodes, data_dim = X_history.shape # if not self.trainable_embeddings: - self.transform = nn.ModuleList( + transform = nn.ModuleList( [ ConditionalSpline( self.hypernet, input_dim=num_nodes*data_dim, count_bins=self.num_bins, order=self.order, bound=5.0 @@ -95,7 +89,7 @@ def log_prob(self, # AffineDiagonalPyro(input_dim=self.num_nodes*self.data_dim) ] ) - self.base_dist = distrib.Normal( + base_dist = distrib.Normal( torch.zeros(num_nodes*data_dim, device=self.device), torch.ones( num_nodes*data_dim, device=self.device) ) @@ -103,17 +97,17 @@ def log_prob(self, context_dict = {"X": X_history, "A": A, "embeddings": embeddings} flow_dist = distrib.ConditionalTransformedDistribution( - self.base_dist, self.transform).condition(context_dict) + base_dist, transform).condition(context_dict) return flow_dist.log_prob(X_input) - def sample(self, - N_samples: int, - X_history: torch.Tensor, + def sample(self, + N_samples: int, + X_history: torch.Tensor, W: torch.Tensor, embeddings: torch.Tensor): assert len(X_history.shape) == 4 - batch, lag, num_nodes, data_dim = X_history.shape + batch, _, num_nodes, data_dim = X_history.shape base_dist = distrib.Normal( torch.zeros(num_nodes*data_dim, device=self.device), torch.ones( diff --git a/src/modules/TemporalHyperNet.py b/src/modules/TemporalHyperNet.py index 71101ae..bf74e2c 100644 --- a/src/modules/TemporalHyperNet.py +++ b/src/modules/TemporalHyperNet.py @@ -1,8 +1,7 @@ +from typing import Dict, Tuple import lightning.pytorch as pl -import torch.nn as nn +from torch import nn import torch -from typing import Dict, List, Optional, Tuple, Type -import math from src.utils.torch_utils import generate_fully_connected class TemporalHyperNet(pl.LightningModule): @@ -26,7 +25,7 @@ def __init__(self, self.order = order self.num_bins = num_bins self.num_nodes = num_nodes - + if self.order == "quadratic": self.param_dim = [ self.num_bins, @@ -40,14 +39,14 @@ def __init__(self, (self.num_bins - 1), self.num_bins, ] # this is for linear order conditional spline flow - + self.total_param = sum(self.param_dim) input_dim = 2*self.embedding_dim self.nn_size = max(4 * num_nodes, self.embedding_dim, 64) self.f = generate_fully_connected( input_dim=input_dim, - output_dim=self.total_param, #potentially num_nodes + output_dim=self.total_param, # potentially num_nodes hidden_dims=[self.nn_size, self.nn_size], non_linearity=nn.LeakyReLU, activation=nn.Identity, @@ -69,10 +68,9 @@ def __init__(self, self.embeddings = nn.Parameter(( torch.randn(self.lag + 1, self.num_nodes, - 
self.embedding_dim, device=self.device) * 0.01 + self.embedding_dim, device=self.device) * 0.01 ), requires_grad=True) # shape (lag+1, num_nodes, embedding_dim) - def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: """ Args: @@ -90,13 +88,13 @@ def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: A = X["A"] X_in = X["X"] embeddings = X["embeddings"] - batch, lag, num_nodes, data_dim = X_in.shape + batch, lag, num_nodes, _ = X_in.shape # ensure we have the correct shape assert (A.shape[0] == batch and A.shape[1] == lag + 1 and A.shape[2] == num_nodes and A.shape[3] == num_nodes) - if embeddings == None: + if embeddings is None: E = self.embeddings.expand( X_in.shape[0], -1, -1, -1 ) @@ -115,7 +113,7 @@ def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: # (batch, num_nodes, embedding_dim) A_temp = A[:, 1:].flip([1]) - X_sum = torch.einsum("blij,blio->bjo", A_temp, X_enc) #/ num_nodes + X_sum = torch.einsum("blij,blio->bjo", A_temp, X_enc) # / num_nodes X_sum = torch.cat((X_sum, E[..., 0, :, :]), dim=-1) diff --git a/src/modules/adjacency_matrices/AdjMatrix.py b/src/modules/adjacency_matrices/AdjMatrix.py index a42cabe..a21f37e 100644 --- a/src/modules/adjacency_matrices/AdjMatrix.py +++ b/src/modules/adjacency_matrices/AdjMatrix.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import torch + class AdjMatrix(ABC): """ Adjacency matrix interface for DECI @@ -25,4 +26,4 @@ def sample_A(self) -> torch.Tensor: """ Returns the adjacency matrix. """ - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/src/modules/adjacency_matrices/MultiTemporalAdjacencyMatrix.py b/src/modules/adjacency_matrices/MultiTemporalAdjacencyMatrix.py index 0e8e33e..a0f517e 100644 --- a/src/modules/adjacency_matrices/MultiTemporalAdjacencyMatrix.py +++ b/src/modules/adjacency_matrices/MultiTemporalAdjacencyMatrix.py @@ -7,6 +7,7 @@ from torch import nn from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix + class MultiTemporalAdjacencyMatrix(pl.LightningModule, AdjMatrix): def __init__( self, @@ -27,25 +28,26 @@ def __init__( assert lag > 0 self.num_graphs = num_graphs self.disable_inst = disable_inst - + if self.threeway: self.logits_inst = nn.Parameter( - torch.zeros((3, self.num_graphs, (num_nodes * (num_nodes - 1)) // 2), - device=self.device), - requires_grad=True + torch.zeros((3, self.num_graphs, (num_nodes * (num_nodes - 1)) // 2), + device=self.device), + requires_grad=True ) self.lower_idxs = torch.unbind( - torch.tril_indices(self.num_nodes, self.num_nodes, offset=-1, device=self.device), 0 + torch.tril_indices(self.num_nodes, self.num_nodes, + offset=-1, device=self.device), 0 ) else: self.logits_inst = nn.Parameter( - torch.zeros((2, self.num_graphs, num_nodes, num_nodes), - device=self.device), - requires_grad=True + torch.zeros((2, self.num_graphs, num_nodes, num_nodes), + device=self.device), + requires_grad=True ) - - self.logits_lag = nn.Parameter(torch.zeros((2, self.num_graphs, lag, num_nodes, num_nodes), - device=self.device), + + self.logits_lag = nn.Parameter(torch.zeros((2, self.num_graphs, lag, num_nodes, num_nodes), + device=self.device), requires_grad=True) self.init_logits = init_logits # Set the init_logits if not None @@ -55,7 +57,7 @@ def __init__( else: self.logits_inst.data[1, :, ...] = self.init_logits[0] self.logits_lag.data[0, :, ...] 
= self.init_logits[1] - + def zero_out_diagonal(self, matrix: torch.Tensor): # matrix: (num_graphs, num_nodes, num_nodes) N = matrix.shape[1] @@ -63,7 +65,7 @@ def zero_out_diagonal(self, matrix: torch.Tensor): matrix = matrix.clone() matrix[:, I, I] = 0 return matrix - + def _triangular_vec_to_matrix(self, vec): """ Given an array of shape (k, N, n(n-1)/2) where k in {2, 3}, creates a matrix of shape @@ -71,39 +73,38 @@ def _triangular_vec_to_matrix(self, vec): triangular is filled from vec[1, :]. """ N = vec.shape[1] - output = torch.zeros((N, self.num_nodes, self.num_nodes), device=self.device) + output = torch.zeros( + (N, self.num_nodes, self.num_nodes), device=self.device) output[:, self.lower_idxs[0], self.lower_idxs[1]] = vec[0, :, ...] output[:, self.lower_idxs[1], self.lower_idxs[0]] = vec[1, :, ...] return output - + def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: """ Returns the adjacency matrix. """ - probs = torch.zeros((self.num_graphs, self.lag + 1, self.num_nodes, self.num_nodes), + probs = torch.zeros((self.num_graphs, self.lag + 1, self.num_nodes, self.num_nodes), device=self.device) - + if not self.disable_inst: - inst_probs = F.softmax(self.logits_inst, dim=0) + inst_probs = F.softmax(self.logits_inst, dim=0) if self.threeway: # (3, n(n-1)/2) probabilities inst_probs = self._triangular_vec_to_matrix(inst_probs) else: inst_probs = self.zero_out_diagonal(inst_probs[1, ...]) - + # Generate simultaneous adj matrix # shape (input_dim, input_dim) - probs[:, 0, ...] = inst_probs - - + probs[:, 0, ...] = inst_probs # Generate lagged adj matrix # shape (lag, input_dim, input_dim) - probs[:, 1:, ...] = F.softmax(self.logits_lag, dim=0)[1, ...] + probs[:, 1:, ...] = F.softmax(self.logits_lag, dim=0)[1, ...] if do_round: return probs.round() - else: - return probs + + return probs def entropy(self) -> torch.Tensor: """ @@ -112,22 +113,26 @@ def entropy(self) -> torch.Tensor: if not self.disable_inst: if self.threeway: - dist = td.Categorical(logits=self.logits_inst[:, :].transpose(0, -1)) + dist = td.Categorical( + logits=self.logits_inst[:, :].transpose(0, -1)) entropies_inst = dist.entropy().sum() else: - dist = td.Categorical(logits=self.logits_inst[1, ...] - self.logits_inst[0, ...]) + dist = td.Categorical( + logits=self.logits_inst[1, ...] - self.logits_inst[0, ...]) I = torch.arange(self.num_nodes) - dist_diag = td.Categorical(logits=self.logits_inst[1, :, I, I] - self.logits_inst[0, :, I, I]) + dist_diag = td.Categorical( + logits=self.logits_inst[1, :, I, I] - self.logits_inst[0, :, I, I]) entropies = dist.entropy() diag_entropy = dist_diag.entropy() entropies_inst = entropies.sum() - diag_entropy.sum() else: entropies_inst = 0 - dist_lag = td.Independent(td.Bernoulli(logits=self.logits_lag[1, :] - self.logits_lag[0, :]), 3) + dist_lag = td.Independent(td.Bernoulli( + logits=self.logits_lag[1, :] - self.logits_lag[0, :]), 3) entropies_lag = dist_lag.entropy().sum() - - return (entropies_lag + entropies_inst) + + return entropies_lag + entropies_inst def sample_A(self) -> torch.Tensor: """ @@ -143,25 +148,26 @@ def sample_A(self) -> torch.Tensor: if self.threeway: # Sample instantaneous adj matrix adj_sample[:, 0, ...] = self._triangular_vec_to_matrix( - F.gumbel_softmax(self.logits_inst, - tau=self.tau_gumbel, - hard=True, - dim=0) + F.gumbel_softmax(self.logits_inst, + tau=self.tau_gumbel, + hard=True, + dim=0) ) # shape (N, input_dim, input_dim) else: - sample = F.gumbel_softmax(self.logits_inst, tau=self.tau_gumbel, hard=True, dim=0)[1, ...] 
+ sample = F.gumbel_softmax( + self.logits_inst, tau=self.tau_gumbel, hard=True, dim=0)[1, ...] adj_sample[:, 0, ...] = self.zero_out_diagonal(sample) - + # Sample lagged adj matrix # shape (N, lag, input_dim, input_dim) - adj_sample[:, 1:, ...] = F.gumbel_softmax(self.logits_lag, - tau=self.tau_gumbel, - hard=True, + adj_sample[:, 1:, ...] = F.gumbel_softmax(self.logits_lag, + tau=self.tau_gumbel, + hard=True, dim=0)[1, ...] return adj_sample def turn_off_inst_grad(self): self.logits_inst.requires_grad_(False) - + def turn_on_inst_grad(self): - self.logits_inst.requires_grad_(True) \ No newline at end of file + self.logits_inst.requires_grad_(True) diff --git a/src/modules/adjacency_matrices/TemporalAdjacencyMatrix.py b/src/modules/adjacency_matrices/TemporalAdjacencyMatrix.py index 6eb98cd..bbc1ceb 100644 --- a/src/modules/adjacency_matrices/TemporalAdjacencyMatrix.py +++ b/src/modules/adjacency_matrices/TemporalAdjacencyMatrix.py @@ -1,9 +1,11 @@ +from typing import List, Optional + import torch import torch.distributions as td import torch.nn.functional as F from torch import nn -from src.modules.adjacency_matrices.ThreeWayGraphDist import ThreeWayGraphDist -from typing import List, Optional +from src.modules.adjacency_matrices.ThreeWayGraphDist import ThreeWayGraphDist + class TemporalAdjacencyMatrix(ThreeWayGraphDist): """ @@ -41,7 +43,8 @@ def __init__( self.lag = lag # Assertion lag > 0 assert lag > 0 - self.logits_lag = nn.Parameter(torch.zeros((2, lag, input_dim, input_dim), device=self.device), requires_grad=True) + self.logits_lag = nn.Parameter(torch.zeros( + (2, lag, input_dim, input_dim), device=self.device), requires_grad=True) self.init_logits = init_logits self.disable_inst = disable_inst # Set the init_logits if not None @@ -60,20 +63,24 @@ def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: """ # Create the temporal adj matrix - probs = torch.zeros(self.lag + 1, self.input_dim, self.input_dim, device=self.device) + probs = torch.zeros(self.lag + 1, self.input_dim, + self.input_dim, device=self.device) # Generate simultaneous adj matrix if not self.disable_inst: - probs[0, ...] = super().get_adj_matrix(do_round=do_round) # shape (input_dim, input_dim) + probs[0, ...] = super().get_adj_matrix( + do_round=do_round) # shape (input_dim, input_dim) # Generate lagged adj matrix - probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[1, ...] # shape (lag, input_dim, input_dim) + probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[ + 1, ...] # shape (lag, input_dim, input_dim) if do_round: return probs.round() - else: - return probs + + return probs def entropy(self) -> torch.Tensor: """ - This computes the entropy of the variational distribution. This can be done by (1) compute the entropy of instantaneous adj matrix(categorical, same as ThreeWayGraphDist), + This computes the entropy of the variational distribution. + This can be done by (1) compute the entropy of instantaneous adj matrix(categorical, same as ThreeWayGraphDist), (2) compute the entropy of lagged adj matrix (Bernoulli dist), and (3) add them together. """ # Entropy for instantaneous dist, call super().entropy @@ -83,11 +90,12 @@ def entropy(self) -> torch.Tensor: entropies_inst = 0 # Entropy for lagged dist # batch_shape [lag], event_shape [num_nodes, num_nodes] - - dist_lag = td.Independent(td.Bernoulli(logits=self.logits_lag[1, ...] - self.logits_lag[0, ...]), 2) + + dist_lag = td.Independent(td.Bernoulli( + logits=self.logits_lag[1, ...] 
- self.logits_lag[0, ...]), 2) entropies_lag = dist_lag.entropy().sum() # entropies_lag = dist_lag.entropy().mean() - + return entropies_lag + entropies_inst def sample_A(self) -> torch.Tensor: @@ -104,13 +112,11 @@ def sample_A(self) -> torch.Tensor: # Sample instantaneous adj matrix if not self.disable_inst: adj_sample[0, ...] = self._triangular_vec_to_matrix( - F.gumbel_softmax(self.logits, tau=self.tau_gumbel, hard=True, dim=0) + F.gumbel_softmax( + self.logits, tau=self.tau_gumbel, hard=True, dim=0) ) # shape (input_dim, input_dim) # Sample lagged adj matrix adj_sample[1:, ...] = F.gumbel_softmax(self.logits_lag, tau=self.tau_gumbel, hard=True, dim=0)[ 1, ... ] # shape (lag, input_dim, input_dim) return adj_sample - - - \ No newline at end of file diff --git a/src/modules/adjacency_matrices/ThreeWayGraphDist.py b/src/modules/adjacency_matrices/ThreeWayGraphDist.py index dae2b21..5c7bac1 100644 --- a/src/modules/adjacency_matrices/ThreeWayGraphDist.py +++ b/src/modules/adjacency_matrices/ThreeWayGraphDist.py @@ -6,6 +6,7 @@ from torch import nn from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix + class ThreeWayGraphDist(AdjMatrix, pl.LightningModule): """ An alternative variational distribution for graph edges. For each pair of nodes x_i and x_j @@ -38,7 +39,8 @@ def __init__( self.tau_gumbel = tau_gumbel self.input_dim = input_dim self.lower_idxs = torch.unbind( - torch.tril_indices(self.input_dim, self.input_dim, offset=-1, device=self.device), 0 + torch.tril_indices(self.input_dim, self.input_dim, + offset=-1, device=self.device), 0 ) def _triangular_vec_to_matrix(self, vec): @@ -47,7 +49,8 @@ def _triangular_vec_to_matrix(self, vec): (n, n) where the lower triangular is filled from vec[0, :] and the upper triangular is filled from vec[1, :]. """ - output = torch.zeros((self.input_dim, self.input_dim), device=self.device) + output = torch.zeros( + (self.input_dim, self.input_dim), device=self.device) output[self.lower_idxs[0], self.lower_idxs[1]] = vec[0, ...] output[self.lower_idxs[1], self.lower_idxs[0]] = vec[1, ...] return output @@ -60,8 +63,7 @@ def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: out_probs = self._triangular_vec_to_matrix(probs) if do_round: return out_probs.round() - else: - return out_probs + return out_probs def entropy(self) -> torch.Tensor: """ @@ -80,6 +82,6 @@ def sample_A(self) -> torch.Tensor: V1: Returns one sample to be used for the whole batch. 
""" - sample = F.gumbel_softmax(self.logits, tau=self.tau_gumbel, hard=True, dim=0) # (3, n(n-1)/2) binary + sample = F.gumbel_softmax( + self.logits, tau=self.tau_gumbel, hard=True, dim=0) # (3, n(n-1)/2) binary return self._triangular_vec_to_matrix(sample) - diff --git a/src/modules/adjacency_matrices/TwoWayGraphDist.py b/src/modules/adjacency_matrices/TwoWayGraphDist.py index a8d3377..a4d6222 100644 --- a/src/modules/adjacency_matrices/TwoWayGraphDist.py +++ b/src/modules/adjacency_matrices/TwoWayGraphDist.py @@ -5,6 +5,7 @@ from torch import nn from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix + class TwoWayGraphDist(AdjMatrix, pl.LightningModule): """ Sampling is performed with `torch.gumbel_softmax(..., hard=True)` to give @@ -30,7 +31,7 @@ def __init__( ) self.tau_gumbel = tau_gumbel self.input_dim = input_dim - + def zero_out_diagonal(self, matrix: torch.Tensor): # matrix: (num_nodes, num_nodes) N = matrix.shape[0] @@ -38,7 +39,7 @@ def zero_out_diagonal(self, matrix: torch.Tensor): matrix = matrix.clone() matrix[I, I] = 0 return matrix - + def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: """ Returns the adjacency matrix of edge probabilities. @@ -49,8 +50,7 @@ def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: if do_round: return probs.round() - else: - return probs + return probs def entropy(self) -> torch.Tensor: """ @@ -58,7 +58,8 @@ def entropy(self) -> torch.Tensor: """ dist = td.Categorical(logits=self.logits[1, ...] - self.logits[0, ...]) I = torch.arange(self.input_dim) - dist_diag = td.Categorical(logits=self.logits[1, I, I] - self.logits[0, I, I]) + dist_diag = td.Categorical( + logits=self.logits[1, I, I] - self.logits[0, I, I]) entropies = dist.entropy() diag_entropy = dist_diag.entropy() return entropies.sum() - diag_entropy.sum() @@ -72,5 +73,6 @@ def sample_A(self) -> torch.Tensor: V1: Returns one sample to be used for the whole batch. """ - sample = F.gumbel_softmax(self.logits, tau=self.tau_gumbel, hard=True, dim=0)[1, ...] - return self.zero_out_diagonal(sample) \ No newline at end of file + sample = F.gumbel_softmax( + self.logits, tau=self.tau_gumbel, hard=True, dim=0)[1, ...] 
+ return self.zero_out_diagonal(sample) diff --git a/src/modules/adjacency_matrices/TwoWayTemporalAdjacencyMatrix.py b/src/modules/adjacency_matrices/TwoWayTemporalAdjacencyMatrix.py index 2a1446b..8bb63b2 100644 --- a/src/modules/adjacency_matrices/TwoWayTemporalAdjacencyMatrix.py +++ b/src/modules/adjacency_matrices/TwoWayTemporalAdjacencyMatrix.py @@ -1,9 +1,11 @@ +from typing import List, Optional + import torch import torch.distributions as td import torch.nn.functional as F from torch import nn from src.modules.adjacency_matrices.TwoWayGraphDist import TwoWayGraphDist -from typing import List, Optional + class TwoWayTemporalAdjacencyMatrix(TwoWayGraphDist): """ @@ -38,7 +40,8 @@ def __init__( self.lag = lag # Assertion lag > 0 assert lag > 0 - self.logits_lag = nn.Parameter(torch.zeros((2, lag, input_dim, input_dim), device=self.device), requires_grad=True) + self.logits_lag = nn.Parameter(torch.zeros( + (2, lag, input_dim, input_dim), device=self.device), requires_grad=True) self.init_logits = init_logits self.disable_inst = disable_inst # Set the init_logits if not None @@ -57,20 +60,24 @@ def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: """ # Create the temporal adj matrix - probs = torch.zeros(self.lag + 1, self.input_dim, self.input_dim, device=self.device) + probs = torch.zeros(self.lag + 1, self.input_dim, + self.input_dim, device=self.device) # Generate simultaneous adj matrix if not self.disable_inst: - probs[0, ...] = super().get_adj_matrix(do_round=do_round) # shape (input_dim, input_dim) + probs[0, ...] = super().get_adj_matrix( + do_round=do_round) # shape (input_dim, input_dim) # Generate lagged adj matrix - probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[1, ...] # shape (lag, input_dim, input_dim) + probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[ + 1, ...] # shape (lag, input_dim, input_dim) if do_round: return probs.round() - else: - return probs + + return probs def entropy(self) -> torch.Tensor: """ - This computes the entropy of the variational distribution. This can be done by (1) compute the entropy of instantaneous adj matrix(categorical, same as ThreeWayGraphDist), + This computes the entropy of the variational distribution. + This can be done by (1) compute the entropy of instantaneous adj matrix(categorical, same as ThreeWayGraphDist), (2) compute the entropy of lagged adj matrix (Bernoulli dist), and (3) add them together. """ # Entropy for instantaneous dist, call super().entropy @@ -80,11 +87,12 @@ def entropy(self) -> torch.Tensor: entropies_inst = 0 # Entropy for lagged dist # batch_shape [lag], event_shape [num_nodes, num_nodes] - - dist_lag = td.Independent(td.Bernoulli(logits=self.logits_lag[1, ...] - self.logits_lag[0, ...]), 2) + + dist_lag = td.Independent(td.Bernoulli( + logits=self.logits_lag[1, ...] - self.logits_lag[0, ...]), 2) entropies_lag = dist_lag.entropy().sum() # entropies_lag = dist_lag.entropy().mean() - + return entropies_lag + entropies_inst def sample_A(self) -> torch.Tensor: @@ -101,9 +109,10 @@ def sample_A(self) -> torch.Tensor: # Sample instantaneous adj matrix if not self.disable_inst: adj_sample[0, ...] = self.zero_out_diagonal( - F.gumbel_softmax(self.logits, tau=self.tau_gumbel, hard=True, dim=0)[1, ...] + F.gumbel_softmax(self.logits, tau=self.tau_gumbel, + hard=True, dim=0)[1, ...] ) # shape (input_dim, input_dim) - + # Sample lagged adj matrix adj_sample[1:, ...] = F.gumbel_softmax(self.logits_lag, tau=self.tau_gumbel, hard=True, dim=0)[ 1, ... 
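Editor's note: the adjacency-matrix classes above all share the same sampling pattern — a hard Gumbel-softmax draw over the edge/no-edge axis of the logits, keeping the "edge" slice and zeroing the instantaneous diagonal. A minimal self-contained sketch of that pattern follows; the shapes and the temperature value are illustrative assumptions, not taken from the configs.

import torch
import torch.nn.functional as F

lag, num_nodes, tau_gumbel = 2, 4, 1.0
logits_inst = torch.zeros(2, num_nodes, num_nodes)       # [no-edge, edge] logits, instantaneous slice
logits_lag = torch.zeros(2, lag, num_nodes, num_nodes)   # [no-edge, edge] logits, lagged slices

adj = torch.zeros(lag + 1, num_nodes, num_nodes)
inst = F.gumbel_softmax(logits_inst, tau=tau_gumbel, hard=True, dim=0)[1]    # hard one-hot, keep "edge" slice
inst[torch.arange(num_nodes), torch.arange(num_nodes)] = 0                   # no instantaneous self-loops
adj[0] = inst
adj[1:] = F.gumbel_softmax(logits_lag, tau=tau_gumbel, hard=True, dim=0)[1]  # lagged slices may keep self-loops
print(adj.shape)  # torch.Size([3, 4, 4]), binary entries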
diff --git a/src/train.py b/src/train.py index 1dc15e4..85a6766 100644 --- a/src/train.py +++ b/src/train.py @@ -1,31 +1,26 @@ # standard libraries -from src.utils.utils import * -from src.utils.metrics_utils import evaluate_results -from tqdm import tqdm -import argparse -import numpy as np +import time import os + +import hydra +import numpy as np import torch -import random +from omegaconf import DictConfig from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.loggers import CSVLogger, WandbLogger -from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping +from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities.model_summary import ModelSummary from lightning.pytorch.utilities import rank_zero_only -from src.utils.config_utils import * +import wandb + +from src.utils.config_utils import add_all_attributes, add_attribute, generate_unique_name from src.utils.data_utils.dataloading_utils import load_data, get_dataset_path, create_save_name from src.utils.data_utils.data_format_utils import zero_out_diag_np +from src.utils.utils import write_results_to_disk +from src.utils.metrics_utils import evaluate_results -import uuid -import time -import wandb from src.model.generate_model import generate_model -from datetime import datetime -import csv -from omegaconf import DictConfig, OmegaConf, open_dict -import hydra -from hydra.utils import instantiate @hydra.main(version_base=None, config_path="../configs/", config_name="main.yaml") def run(cfg: DictConfig): @@ -38,13 +33,15 @@ def run(cfg: DictConfig): if dataset_path in cfg: model_config = cfg[dataset_path] else: - raise Exception("No model found in the config. Try running with option: python3 -m src.train +dataset= +=") - + raise Exception( + "No model found in the config. 
Try running with option: python3 -m src.train +dataset= +=") + cfg.dataset_dir = os.path.join(cfg.dataset_dir, dataset_path) add_all_attributes(cfg, model_config) train(cfg) + def train(cfg): print("Running config:") @@ -54,8 +51,9 @@ def train(cfg): # set seed seed_everything(cfg.random_seed) - X, adj_matrix, aggregated_graph, lag, data_dim = load_data(cfg.dataset, cfg.dataset_dir, cfg) - + X, adj_matrix, aggregated_graph, lag, data_dim = load_data( + cfg.dataset, cfg.dataset_dir, cfg) + add_attribute(cfg, 'lag', lag) add_attribute(cfg, 'aggregated_graph', aggregated_graph) add_attribute(cfg, 'num_nodes', X.shape[2]) @@ -65,18 +63,18 @@ def train(cfg): model = generate_model(cfg) # pass the dataset - model = model(full_dataset = X, - adj_matrices = adj_matrix) + model = model(full_dataset=X, + adj_matrices=adj_matrix) model_name = cfg.model if 'use_indices' in cfg: - f_path = os.path.join(cfg.dataset_dir, cfg.dataset, f'{cfg.use_indices}_seed={cfg.random_seed}.npy') + f_path = os.path.join(cfg.dataset_dir, cfg.dataset, + f'{cfg.use_indices}_seed={cfg.random_seed}.npy') mix_idx = torch.Tensor(np.load(f_path)) model.set_mixture_indices(mix_idx) - - training_needed = cfg.model == 'rhino' or cfg.model == 'mcd' + training_needed = cfg.model in ['rhino', 'mcd'] unique_name = generate_unique_name(cfg) csv_logger = CSVLogger("logs", name=unique_name) wandb_logger = WandbLogger( @@ -90,12 +88,13 @@ def train(cfg): csv_logger.log_hyperparams(cfg) if training_needed: - monitor_checkpoint_based_on = cfg.monitor_checkpoint_based_on # either val_loss or likelihood + # either val_loss or likelihood + monitor_checkpoint_based_on = cfg.monitor_checkpoint_based_on ckpt_choice = 'best' checkpoint_callback = ModelCheckpoint( save_top_k=1, monitor=monitor_checkpoint_based_on, mode="min", save_last=True) - + if len(cfg.gpu) > 1: strategy = 'ddp_find_unused_parameters_true' else: @@ -107,21 +106,21 @@ def train(cfg): val_every_n_epochs = 1 trainer = Trainer(max_epochs=cfg.num_epochs, - accelerator="gpu", - devices=cfg.gpu, - precision=cfg.precision, - logger=[csv_logger, wandb_logger], - callbacks=[checkpoint_callback], - strategy=strategy, - enable_progress_bar=True, - check_val_every_n_epoch=val_every_n_epochs) - + accelerator="gpu", + devices=cfg.gpu, + precision=cfg.precision, + logger=[csv_logger, wandb_logger], + callbacks=[checkpoint_callback], + strategy=strategy, + enable_progress_bar=True, + check_val_every_n_epoch=val_every_n_epochs) + summary = ModelSummary(model, max_depth=10) print(summary) - + if cfg.watch_gradients: wandb_logger.watch(model) - + start_time = time.time() trainer.fit(model=model) end_time = time.time() @@ -133,7 +132,7 @@ def train(cfg): print("WARNING: GPU specified, but baseline cannot use GPU.") trainer = Trainer(logger=[csv_logger, wandb_logger], accelerator='cpu') - + # get predictions full_dataloader = model.get_full_dataloader() @@ -169,12 +168,14 @@ def train(cfg): if training_needed and ckpt_choice == 'best': if model_name == 'mcd': - np.save(os.path.join('results', dataset, f'{model_name}_{checkpoint_callback.best_model_score.item()}_k{cfg.trainer.num_graphs}.npy'), scores) + np.save(os.path.join('results', dataset, + f'{model_name}_{checkpoint_callback.best_model_score.item()}_k{cfg.trainer.num_graphs}.npy'), scores) else: - np.save(os.path.join('results', dataset, f'{model_name}_{checkpoint_callback.best_model_score.item()}.npy'), scores) + np.save(os.path.join('results', dataset, + f'{model_name}_{checkpoint_callback.best_model_score.item()}.npy'), scores) else: 
np.save(os.path.join('results', dataset, f'{model_name}.npy'), scores) - + if model_name == 'mcd': true_cluster_indices, pred_cluster_indices = model.get_cluster_indices() else: @@ -182,22 +183,25 @@ def train(cfg): true_cluster_indices = None metrics = evaluate_results(scores=scores, - adj_matrix=adj_matrix, - predictions=predictions, - aggregated_graph=aggregated_graph, - true_cluster_indices=true_cluster_indices, - pred_cluster_indices=pred_cluster_indices) + adj_matrix=adj_matrix, + predictions=predictions, + aggregated_graph=aggregated_graph, + true_cluster_indices=true_cluster_indices, + pred_cluster_indices=pred_cluster_indices) # add the dataset name and model to the csv metrics['model'] = model_name + "_seed_" + str(seed) - if model_name == 'pcmci' or model_name == 'dynotears': - metrics['model'] += "_singlegraph_" + str(cfg.trainer.single_graph) + "_grouped_" + str(cfg.trainer.group_by_graph) + if model_name in ['pcmci', 'dynotears']: + metrics['model'] += "_singlegraph_" + \ + str(cfg.trainer.single_graph) + "_grouped_" + \ + str(cfg.trainer.group_by_graph) if model_name == 'mcd': - metrics['model'] += "_trueindex_" + str(cfg.trainer.use_correct_mixture_index) + metrics['model'] += "_trueindex_" + \ + str(cfg.trainer.use_correct_mixture_index) if 'use_indices' in cfg: metrics['model'] += '_' + cfg.use_indices - if (model_name == 'rhino' or model_name == 'mcd') and 'linear' in cfg.decoder and cfg.decoder.linear: - metrics['model'] += '_linearmodel' + if model_name in ['rhino', 'mcd'] and 'linear' in cfg.decoder and cfg.decoder.linear: + metrics['model'] += '_linearmodel' if training_needed and ckpt_choice == 'best': metrics['best_loss'] = checkpoint_callback.best_model_score.item() metrics['dataset'] = dataset diff --git a/src/training/auglag.py b/src/training/auglag.py index 2409381..b775a97 100644 --- a/src/training/auglag.py +++ b/src/training/auglag.py @@ -4,13 +4,13 @@ from collections import deque from dataclasses import dataclass, field -from typing import Any, Optional, Union, Dict +from typing import Any, Optional, Dict import torch -from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.utilities.types import STEP_OUTPUT + class AugLagLossCalculator(torch.nn.Module): def __init__(self, init_alpha: float, init_rho: float): super().__init__() @@ -19,8 +19,10 @@ def __init__(self, init_alpha: float, init_rho: float): self.alpha: torch.Tensor self.rho: torch.Tensor - self.register_buffer("alpha", torch.tensor(self.init_alpha, dtype=torch.float)) - self.register_buffer("rho", torch.tensor(self.init_rho, dtype=torch.float)) + self.register_buffer("alpha", torch.tensor( + self.init_alpha, dtype=torch.float)) + self.register_buffer("rho", torch.tensor( + self.init_rho, dtype=torch.float)) def forward(self, objective: torch.Tensor, constraint: torch.Tensor) -> torch.Tensor: return objective + self.alpha * constraint + self.rho * constraint * constraint / 2 @@ -55,7 +57,8 @@ class AugLagLRConfig: lr_update_lag: int = 500 lr_update_lag_best: int = 250 lr_init_dict: Dict[str, float] = field( - default_factory=lambda: {"vardist": 0.1, "functional_relationships": 0.0003, "noise_dist": 0.003, 'linear_causal_graph': 1} + default_factory=lambda: { + "vardist": 0.1, "functional_relationships": 0.0003, "noise_dist": 0.003, 'linear_causal_graph': 1} ) aggregation_period: int = 20 lr_factor: float = 0.1 @@ -73,6 +76,7 @@ class AugLagLRConfig: max_inner_steps: int = 3000 force_not_converged: bool = False + @dataclass class AugLagLRDYNOTEARSConfig:
"""Configuration parameters for the AuglagLR scheduler. @@ -93,9 +97,10 @@ class AugLagLRDYNOTEARSConfig: penalty_tolerance: float = 1e-5 max_opt_iter: int = 200 lr_init_dict: Dict[str, float] = field( - default_factory = lambda: {"w": 0.2, 'mixing_probs': 1} + default_factory=lambda: {"w": 0.2, 'mixing_probs': 1} ) - + + class AugLagLR: """A Pytorch like scheduler which performs the Augmented Lagrangian optimization procedure. @@ -116,7 +121,8 @@ def __init__(self, config: AugLagLRConfig) -> None: self._prev_lagrangian_penalty = torch.tensor(torch.inf) self._cur_lagrangian_penalty = torch.tensor(torch.inf) - self.loss_tracker: deque[torch.Tensor] = deque([], maxlen=config.aggregation_period) + self.loss_tracker: deque[torch.Tensor] = deque( + [], maxlen=config.aggregation_period) self._init_new_inner_optimisation() # Track whether auglag is disabled and the state of the loss when it was disabled @@ -139,10 +145,15 @@ def _is_inner_converged(self) -> bool: Returns: bool: Return True if converged, else False. """ - if self.step_counter >= self.config.max_inner_steps or self.num_lr_updates >= self.config.max_lr_down or self.last_best_step + self.config.inner_early_stopping_patience <= self.step_counter: - print("Step counter condition", self.step_counter >= self.config.max_inner_steps) - print("Update condition:", self.num_lr_updates >= self.config.max_lr_down) - print("Early stopping condition:", self.last_best_step + self.config.inner_early_stopping_patience <= self.step_counter) + if self.step_counter >= self.config.max_inner_steps \ + or self.num_lr_updates >= self.config.max_lr_down \ + or self.last_best_step + self.config.inner_early_stopping_patience <= self.step_counter: + print("Step counter condition", self.step_counter >= + self.config.max_inner_steps) + print("Update condition:", self.num_lr_updates >= + self.config.max_lr_down) + print("Early stopping condition:", self.last_best_step + + self.config.inner_early_stopping_patience <= self.step_counter) return ( self.step_counter >= self.config.max_inner_steps @@ -164,9 +175,12 @@ def _is_outer_converged(self) -> bool: return self.outer_opt_counter >= self.config.max_outer_steps if self.outer_opt_counter >= self.config.max_outer_steps or self.outer_below_penalty_tol >= self.config.patience_penalty_reached or self.outer_max_rho >= self.config.patience_max_rho: - print("Outer opt condition:", self.outer_opt_counter >= self.config.max_outer_steps) - print("Penalty condition:", self.outer_below_penalty_tol >= self.config.patience_penalty_reached) - print("Rho condition:", self.outer_max_rho >= self.config.patience_max_rho) + print("Outer opt condition:", self.outer_opt_counter >= + self.config.max_outer_steps) + print("Penalty condition:", self.outer_below_penalty_tol >= + self.config.patience_penalty_reached) + print("Rho condition:", self.outer_max_rho >= + self.config.patience_max_rho) return ( self.outer_opt_counter >= self.config.max_outer_steps @@ -204,11 +218,13 @@ def _update_lr(self, optimizer): for opt in optimizer: for param_group in opt.param_groups: param_group["lr"] *= self.config.lr_factor - print("Setting lr:", param_group["lr"], "for", param_group["name"]) + print("Setting lr:", + param_group["lr"], "for", param_group["name"]) else: for param_group in optimizer.param_groups: param_group["lr"] *= self.config.lr_factor - print("Setting lr:", param_group["lr"], "for", param_group["name"]) + print("Setting lr:", + param_group["lr"], "for", param_group["name"]) def reset_lr(self, optimizer): """Reset the learning rate of 
individual param groups from lr init dictionary. @@ -222,11 +238,13 @@ def reset_lr(self, optimizer): for opt in optimizer: for param_group in opt.param_groups: param_group["lr"] = self.config.lr_init_dict[param_group["name"]] - print("Resetting lr to", param_group["lr"], "for", param_group["name"]) + print("Resetting lr to", + param_group["lr"], "for", param_group["name"]) else: for param_group in optimizer.param_groups: param_group["lr"] = self.config.lr_init_dict[param_group["name"]] - print("Resetting lr to", param_group["lr"], "for", param_group["name"]) + print("Resetting lr to", + param_group["lr"], "for", param_group["name"]) def _update_lagrangian_params(self, loss: AugLagLossCalculator): """Update the lagrangian parameters (of the auglag routine) based on the dag constraint values observed. @@ -243,7 +261,8 @@ def _update_lagrangian_params(self, loss: AugLagLossCalculator): self.outer_max_rho += 1 if self._cur_lagrangian_penalty > self._prev_lagrangian_penalty * self.config.penalty_progress_rate: - print(f"Updating rho, dag penalty prev: {self._prev_lagrangian_penalty: .10f}") + print( + f"Updating rho, dag penalty prev: {self._prev_lagrangian_penalty: .10f}") loss.rho *= 10.0 print("Rho", loss.rho.item(), " Alpha", loss.alpha.item()) else: @@ -257,8 +276,10 @@ def _update_lagrangian_params(self, loss: AugLagLossCalculator): loss.alpha *= 5 # Update parameters and make sure to maintain the dtype and device - loss.alpha = torch.min(loss.alpha, torch.full_like(loss.alpha, self.config.safety_alpha)) - loss.rho = torch.min(loss.rho, torch.full_like(loss.rho, self.config.safety_rho)) + loss.alpha = torch.min(loss.alpha, torch.full_like( + loss.alpha, self.config.safety_alpha)) + loss.rho = torch.min(loss.rho, torch.full_like( + loss.rho, self.config.safety_rho)) def _is_auglag_converged(self, optimizer, loss: AugLagLossCalculator) -> bool: """Checks if the inner and outer loops have converged. If inner loop is converged, @@ -366,22 +387,24 @@ def step( """ if self.disabled: return False - assert torch.all(lagrangian_penalty >= 0), "auglag penalty must be non-negative" + assert torch.all(lagrangian_penalty >= + 0), "auglag penalty must be non-negative" self._update_loss_tracker(loss_value.detach()) self._cur_lagrangian_penalty = lagrangian_penalty.detach() self.step_counter += 1 if self.step_counter % 100 == 0: - print(f"Step:{self.step_counter} loss:{loss_value.item():.3f} "+\ - f"likelihood:{likelihood.item():.3f} dag:{self._cur_lagrangian_penalty.item():.3f} "+\ + print(f"Step:{self.step_counter} loss:{loss_value.item():.3f} " + + f"likelihood:{likelihood.item():.3f} dag:{self._cur_lagrangian_penalty.item():.3f} " + f"graph prior:{graph_prior.item():.3f} graph entropy:{graph_entropy.item():.3f}") self._check_best_loss() return self._is_auglag_converged(optimizer=optimizer, loss=loss) - + + class AuglagLRCallback(pl.Callback): """Wrapper Class to make the Auglag Learning Rate Scheduler compatible with Pytorch Lightning""" - def __init__(self, scheduler: AugLagLR, log_auglag: bool = False, disabled_epochs = None): + def __init__(self, scheduler: AugLagLR, log_auglag: bool = False, disabled_epochs=None): """ Args: scheduler: The auglag learning rate scheduler to wrap. 
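Editor's note: for context on what the scheduler above is driving, AugLagLossCalculator combines the training objective with the DAG constraint as objective + alpha * constraint + rho * constraint^2 / 2, where alpha and rho are the buffers updated by _update_lagrangian_params. A small usage sketch with placeholder values (the real objective and constraint come from the model's training step):

import torch
from src.training.auglag import AugLagLossCalculator

loss_calc = AugLagLossCalculator(init_alpha=0.0, init_rho=1.0)
objective = torch.tensor(2.5)    # e.g. a negative log-likelihood (placeholder value)
constraint = torch.tensor(0.1)   # e.g. a DAG-ness penalty h(A) (placeholder value)
loss = loss_calc(objective, constraint)
print(loss.item())               # 2.5 + 0.0*0.1 + 1.0*0.1**2/2 ≈ 2.505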
@@ -434,5 +457,5 @@ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) "last_best_step": float(self.scheduler.last_best_step), "last_lr_update_step": float(self.scheduler.last_lr_update_step), } - pl_module.log_dict(auglag_state, on_epoch=True, rank_zero_only=True, prog_bar=False) - + pl_module.log_dict(auglag_state, on_epoch=True, + rank_zero_only=True, prog_bar=False) diff --git a/src/utils/causality_utils.py b/src/utils/causality_utils.py index c6de642..40bde8c 100644 --- a/src/utils/causality_utils.py +++ b/src/utils/causality_utils.py @@ -2,14 +2,10 @@ Borrowed from github.com/microsoft/causica """ -from itertools import product -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Union import numpy as np import torch -import torch.distributions as td -from sklearn import metrics - def convert_temporal_to_static_adjacency_matrix( adj_matrix: np.ndarray, conversion_type: str, fill_value: Union[float, int] = 0.0 @@ -43,11 +39,13 @@ def convert_temporal_to_static_adjacency_matrix( if conversion_type == "full_time": block_fill_value = np.full((n_nodes, n_nodes), fill_value) else: - block_fill_value = np.full((batch_dim, n_lag * n_nodes, (n_lag - 1) * n_nodes), fill_value) + block_fill_value = np.full( + (batch_dim, n_lag * n_nodes, (n_lag - 1) * n_nodes), fill_value) if conversion_type == "full_time": static_adj = np.sum( - np.stack([np.kron(np.diag(np.ones(n_lag - i), k=i), adj_matrix[:, i, :, :]) for i in range(n_lag)], axis=1), + np.stack([np.kron(np.diag(np.ones(n_lag - i), k=i), + adj_matrix[:, i, :, :]) for i in range(n_lag)], axis=1), axis=1, ) # [N, n_lag*from, n_lag*to] static_adj += np.kron( @@ -62,7 +60,8 @@ def convert_temporal_to_static_adjacency_matrix( -1, n_lag * n_nodes, n_nodes ) # [N, (lag+1)*num_node, num_node] # Static graph - static_adj = np.concatenate((block_fill_value, block_column), axis=2) # [N, (lag+1)*num_node, (lag+1)*num_node] + # [N, (lag+1)*num_node, (lag+1)*num_node] + static_adj = np.concatenate((block_fill_value, block_column), axis=2) return np.squeeze(static_adj) @@ -88,17 +87,21 @@ def approximate_maximal_acyclic_subgraph(adj_matrix: np.ndarray, n_samples: int # assign each node with a order adj_dag = np.zeros_like(adj_matrix) for _ in range(n_samples): - random_order = np.expand_dims(np.random.permutation(adj_matrix.shape[0]), 0) + random_order = np.expand_dims( + np.random.permutation(adj_matrix.shape[0]), 0) # subgraph with only forward edges defined by the assigned order - adj_forward = ((random_order.T > random_order).astype(int)) * adj_matrix + adj_forward = ( + (random_order.T > random_order).astype(int)) * adj_matrix # subgraph with only backward edges defined by the assigned order - adj_backward = ((random_order.T < random_order).astype(int)) * adj_matrix + adj_backward = ( + (random_order.T < random_order).astype(int)) * adj_matrix # return the subgraph with the least deleted edges adj_dag_n = adj_forward if adj_backward.sum() < adj_forward.sum() else adj_backward if adj_dag_n.sum() > adj_dag.sum(): adj_dag = adj_dag_n return adj_dag + def int2binlist(i: int, n_bits: int): """ Convert integer to list of ints with values in {0, 1} @@ -107,6 +110,7 @@ def int2binlist(i: int, n_bits: int): str_list = list(np.binary_repr(i, n_bits)) return [int(i) for i in str_list] + def cpdag2dags(cp_mat: np.ndarray, samples: Optional[int] = None) -> np.ndarray: """ Compute all possible DAGs contained within a Markov equivalence class, given by a CPDAG @@ -130,16 +134,18 @@ def 
cpdag2dags(cp_mat: np.ndarray, samples: Optional[int] = None) -> np.ndarray: # prune cycles if the matrix of determined edges is not a dag if dag_pen_np(cp_determined_subgraph.copy()) != 0.0: - cp_determined_subgraph = approximate_maximal_acyclic_subgraph(cp_determined_subgraph, 1000) + cp_determined_subgraph = approximate_maximal_acyclic_subgraph( + cp_determined_subgraph, 1000) # number of parent nodes for each node under the well determined matrix - N_in_nodes = cp_determined_subgraph.sum(axis=0) + n_in_nodes = cp_determined_subgraph.sum(axis=0) # lower triangular version of cycles edges: only keep cycles in one direction. cycles_tril = np.tril(cycle_mat, k=-1) # indices of potential new edges - undetermined_idx_mat = np.array(np.nonzero(cycles_tril)).T # (N_undedetermined, 2) + undetermined_idx_mat = np.array(np.nonzero( + cycles_tril)).T # (N_undedetermined, 2) # number of undetermined edges N_undetermined = int(cycles_tril.sum()) @@ -163,18 +169,22 @@ def cpdag2dags(cp_mat: np.ndarray, samples: Optional[int] = None) -> np.ndarray: mask = np.array(int2binlist(mask_index, N_undetermined)) # extract list of indices which our new edges are pointing into - incoming_edges = np.take_along_axis(undetermined_idx_mat, mask[:, None], axis=1).squeeze() + incoming_edges = np.take_along_axis( + undetermined_idx_mat, mask[:, None], axis=1).squeeze() # check if multiple edges are pointing at same node - _, unique_counts = np.unique(incoming_edges, return_index=False, return_inverse=False, return_counts=True) + _, unique_counts = np.unique( + incoming_edges, return_index=False, return_inverse=False, return_counts=True) # check if new colider has been created by checkig if multiple edges point at same node or if new edge points at existing child node - new_colider = np.any(unique_counts > 1) or np.any(N_in_nodes[incoming_edges] > 0) + new_colider = np.any(unique_counts > 1) or np.any( + n_in_nodes[incoming_edges] > 0) if not new_colider: # get indices of new edges by sampling from lower triangular mat and upper triangular according to indices edge_selection = undetermined_idx_mat.copy() - edge_selection[mask == 0, :] = np.fliplr(edge_selection[mask == 0, :]) + edge_selection[mask == 0, :] = np.fliplr( + edge_selection[mask == 0, :]) # add new edges to matrix and add to dag list new_dag = cp_determined_subgraph.copy() @@ -187,4 +197,4 @@ def cpdag2dags(cp_mat: np.ndarray, samples: Optional[int] = None) -> np.ndarray: if len(dag_list) == 0: dag_list.append(cp_determined_subgraph) - return np.stack(dag_list, axis=0) \ No newline at end of file + return np.stack(dag_list, axis=0) diff --git a/src/utils/config_utils.py b/src/utils/config_utils.py index e260335..8da58da 100644 --- a/src/utils/config_utils.py +++ b/src/utils/config_utils.py @@ -1,7 +1,8 @@ from datetime import datetime -from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf import DictConfig, open_dict from sklearn.model_selection import ParameterGrid + def generate_unique_name(config): # generate unique name based on the config run_name = config['model']+"_"+config['dataset'] + \ @@ -9,22 +10,24 @@ def generate_unique_name(config): run_name += datetime.now().strftime("%Y%m%d_%H_%M_%S") return run_name + def read_optional(config, arg, default): if arg in config: return config[arg] - else: - return default + return default def add_attribute(config: DictConfig, name, val): with open_dict(config): config[name] = val + def add_all_attributes(cfg, cfg2): # add all attributes from cfg2 to cfg for key in cfg2: add_attribute(cfg, 
key, cfg2[key]) - + + def build_subdictionary(hyperparameters, loop_hyperparameters): """ Given dictionary of hyperparameters (where some of the values may be lists) and a list of keys diff --git a/src/utils/data_gen/data_generation_utils.py b/src/utils/data_gen/data_generation_utils.py index 436004b..9163cc1 100644 --- a/src/utils/data_gen/data_generation_utils.py +++ b/src/utils/data_gen/data_generation_utils.py @@ -125,7 +125,8 @@ def random_acyclic_orientation(B_und: np.ndarray) -> np.ndarray: def generate_single_graph(num_nodes: int, graph_type: str, graph_config: dict, is_DAG: bool = True) -> np.ndarray: """ This will generate a single adjacency matrix following different graph generation methods (specified by graph_type, can be "ER", "SF", "SBM"). - graph_config specifes the additional configurations for graph_type. For example, for "ER", the config dict keys can be {"p", "m", "directed", "loop"}, + graph_config specifies the additional configurations for graph_type. + For example, for "ER", the config dict keys can be {"p", "m", "directed", "loop"}, refer to igraph for details. is_DAG is to ensure the generated graph is a DAG by lower-trianguler the adj, followed by a permutation. Note that SBM will no longer be a proper SBM if is_DAG=True Args: @@ -443,10 +444,11 @@ def generate_cts_temporal_data( noise_function_type: str = "spline", save_data: bool = True, base_noise_type: str = "gaussian", - temporal_graphs = None + temporal_graphs=None ) -> np.ndarray: """ - This will generate continuous time-series data (with history-depdendent noise). It will start to collect the data after the burnin_length for stationarity. + This will generate continuous time-series data (with history-dependent noise). + It will start to collect the data after the burnin_length for stationarity. Args: path: The output dir path. series_length: The time series length to be generated.
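Editor's note: to make the graph_config convention in the docstring above concrete, here is an illustrative per-slice configuration in the style of the generation scripts further below; N and the edge counts are placeholder values, and the list is what the scripts pass to generate_temporal_graph.

# Illustrative only: "m" is the ER edge-count key described in the docstring above.
N = 10                      # number of nodes (placeholder)
disable_inst = False        # set True to drop instantaneous edges entirely
graph_type = ["ER", "ER"]   # [instantaneous, lagged] generation methods
graph_config = [
    {"m": N * 2 if not disable_inst else 0, "directed": True},  # instantaneous slice
    {"m": N, "directed": True},                                 # lagged slices
]
# e.g. generate_temporal_graph(N, graph_type, graph_config, lag=2), as in generate_perturb_syn.py below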
@@ -632,8 +634,7 @@ def sample_function(input_dim: int, function_type: str) -> Callable: return sample_mlp_noise(input_dim) elif function_type == 'linear': return sample_linear(input_dim) - else: - raise ValueError(f"Unsupported function type: {function_type}") + raise ValueError(f"Unsupported function type: {function_type}") def sample_inverse_noise_spline(input_dim): @@ -718,15 +719,18 @@ def func(X): return func + def sample_linear(input_dim): # sample weights - W = np.random.binomial(n=1, p=0.5, size=(input_dim))*np.random.uniform(0.1, 0.5, size=(input_dim)) - + W = np.random.binomial(n=1, p=0.5, size=(input_dim)) * \ + np.random.uniform(0.1, 0.5, size=(input_dim)) + def func(X): return X@W return func + def zero_func() -> np.ndarray: return np.zeros(1) @@ -752,7 +756,6 @@ def generate_name( else: flag = "HistDep" - if disable_inst: file_name = ( f"{graph_type[0]}_{graph_type[1]}_num_graphs_{num_graphs}_lag_{lag}_dim_{num_nodes}_{flag}_{noise_level}_{function_type}_" diff --git a/src/utils/data_gen/generate_perturb_syn.py b/src/utils/data_gen/generate_perturb_syn.py index 1e6461b..3adaa96 100644 --- a/src/utils/data_gen/generate_perturb_syn.py +++ b/src/utils/data_gen/generate_perturb_syn.py @@ -1,19 +1,20 @@ +import os import argparse import yaml -import os -import json -from data_generation_utils import generate_cts_temporal_data, generate_name, set_random_seed, generate_temporal_graph +from data_generation_utils import generate_cts_temporal_data, set_random_seed, generate_temporal_graph import cdt import numpy as np import networkx as nx + def calc_dist(adj_matrix): unique_matrices = np.unique(adj_matrix, axis=0) distances = [] for i in range(unique_matrices.shape[0]): for j in range(i): - distances.append(cdt.metrics.SHD(unique_matrices[i], unique_matrices[j])) + distances.append(cdt.metrics.SHD( + unique_matrices[i], unique_matrices[j])) mean_dist = np.mean(distances) min_dist = np.min(distances) max_dist = np.max(distances) @@ -21,6 +22,7 @@ def calc_dist(adj_matrix): return mean_dist, std_dist, min_dist, max_dist + def perturb_graph(G, p): retry_counter = 0 while True: @@ -30,9 +32,8 @@ def perturb_graph(G, p): nxG = nx.DiGraph(perturbed[0]) if nx.is_directed_acyclic_graph(nxG): break - else: - retry_counter += 1 - + retry_counter += 1 + if retry_counter >= 200000: assert False, "Cannot generate DAG, try a lower value of p" @@ -42,14 +43,15 @@ def perturb_graph(G, p): def main(config_file): # read the yaml file - with open(config_file) as f: + with open(config_file, 'r', encoding="utf-8") as f: data_config = yaml.load(f, Loader=yaml.FullLoader) - + series_length = int(data_config["num_timesteps"]) burnin_length = int(data_config["burnin_length"]) num_samples = int(data_config["num_samples"]) disable_inst = bool(data_config["disable_inst"]) - graph_type = [data_config["inst_graph_type"], data_config["lag_graph_type"]] + graph_type = [data_config["inst_graph_type"], + data_config["lag_graph_type"]] p_array = data_config['p_array'] connection_factor = 1 @@ -77,22 +79,25 @@ def main(config_file): {"m": N * 2 * connection_factor if not disable_inst else 0, "directed": True}, {"m": N * connection_factor, "directed": True}, ] - G = generate_temporal_graph(N, graph_type, graph_config, lag=2).astype(int) + G = generate_temporal_graph( + N, graph_type, graph_config, lag=2).astype(int) N = int(N) for N_G in num_graphs: N_G = int(N_G) print(f"Generating dataset for N={N}, num_graphs={N_G}") - + for p in p_array: graphs = [] for i in range(N_G): print(f"Generating graph {i}/{N_G}") Gtilde = 
perturb_graph(G, p)
            graphs.append(Gtilde)
-
-        mean_dist, std_dist, min_dist, max_dist = calc_dist(np.array(graphs))
-        print(f"Perturbation,{N},{N_G},{mean_dist},{std_dist},{min_dist},{max_dist},{p}")
+
+        mean_dist, std_dist, min_dist, max_dist = calc_dist(
+            np.array(graphs))
+        print(
+            f"Perturbation,{N},{N_G},{mean_dist},{std_dist},{min_dist},{max_dist},{p}")

         folder_name = f"perturb_N{N}_K{N_G}_p{p}_seed{seed}"
         path = os.path.join(save_dir, folder_name)
diff --git a/src/utils/data_gen/generate_stock.py b/src/utils/data_gen/generate_stock.py
index b67f971..e2df15d 100644
--- a/src/utils/data_gen/generate_stock.py
+++ b/src/utils/data_gen/generate_stock.py
@@ -1,9 +1,10 @@
+import argparse
+import os
 from yahoofinancials import YahooFinancials
 import pandas as pd
 import numpy as np
 import tqdm
-import argparse
-import os
+

 def process_stock_price(tickers, start_date, end_date):
     yfs = [YahooFinancials(t) for t in tickers]
@@ -25,6 +26,7 @@ def process_stock_price(tickers, start_date, end_date):
 def get_log_returns(X):
     return np.diff(np.log(X))

+
 def standardize(X):
     mu = np.mean(X, axis=1)
     std = np.std(X, axis=1)
@@ -32,6 +34,7 @@ def standardize(X):
     std = np.repeat(std[:, np.newaxis], repeats=X.shape[1], axis=1)
     return (X-mu)/std

+
 def generate_stock(args):
     df = pd.read_csv(args.stock_list_file)
     tickers = df['Symbol'].tolist()
@@ -40,13 +43,13 @@ def generate_stock(args):

     X = get_log_returns(X)
     X = standardize(X)
-
+
     D = X.shape[0]
     T = X.shape[1]
     L = args.chunk_size
     dat = np.zeros((D, T//L, L))
     date_array = []
-
+
     for i in range(T//L):
         dat[:, i] = X[:, L*i:(i+1)*L]
         date_array.append(dates[0][L*i])
@@ -61,19 +64,23 @@ def generate_stock(args):

     for sector in df['Sector'].unique():
         X_sector = dat[:, :, df[df['Sector'] == sector].index.values]
-        np.save(os.path.join(args.save_dir, f'X_{sector.replace(" ", "")}.npy'), X_sector)
+        np.save(os.path.join(args.save_dir,
+                             f'X_{sector.replace(" ", "")}.npy'), X_sector)

-    with open(os.path.join(args.save_dir, 'dates.csv'), 'w') as f:
+    with open(os.path.join(args.save_dir, 'dates.csv'), 'w', encoding="utf-8") as f:
         for i in range(len(date_array)-1):
             f.write(f"{date_array[i]} to {date_array[i+1]}\n")

+
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser("Stock data generator from Yahoo Financials")
+    parser = argparse.ArgumentParser(
+        "Stock data generator from Yahoo Financials")
     parser.add_argument("--start_date", type=str, default='2016-01-01')
     parser.add_argument("--end_date", type=str, default='2023-07-01')
-    parser.add_argument('--chunk_size', type=int, default=31, help='Number of days to chunk together into one sample. Default is 31 days')
+    parser.add_argument('--chunk_size', type=int, default=31,
+                        help='Number of days to chunk together into one sample. Default is 31 days')
     parser.add_argument('--stock_list_file', type=str)
     parser.add_argument('--save_dir', type=str)
-
+
     args = parser.parse_args()
     generate_stock(args)
diff --git a/src/utils/data_gen/generate_synthetic_data.py b/src/utils/data_gen/generate_synthetic_data.py
index d37d0b4..60c17b1 100644
--- a/src/utils/data_gen/generate_synthetic_data.py
+++ b/src/utils/data_gen/generate_synthetic_data.py
@@ -1,22 +1,21 @@
+import os
 import argparse
 import yaml
-import os
-import json
-
 from data_generation_utils import generate_cts_temporal_data, generate_name, set_random_seed


 def main(config_file):
     # read the yaml file
-    with open(config_file) as f:
+    with open(config_file, encoding="utf-8") as f:
         data_config = yaml.load(f, Loader=yaml.FullLoader)
-
+
     series_length = int(data_config["num_timesteps"])
     burnin_length = int(data_config["burnin_length"])
     num_samples = int(data_config["num_samples"])
     disable_inst = bool(data_config["disable_inst"])
-    graph_type = [data_config["inst_graph_type"], data_config["lag_graph_type"]]
+    graph_type = [data_config["inst_graph_type"],
+                  data_config["lag_graph_type"]]

     connection_factor = 1
@@ -46,7 +45,8 @@ def main(config_file):
         N = int(N)
         N_G = int(N_G)
         graph_config = [
-            {"m": N * 2 * connection_factor if not disable_inst else 0, "directed": True},
+            {"m": N * 2 * connection_factor if not disable_inst else 0,
+             "directed": True},
             {"m": N * connection_factor, "directed": True},
         ]
@@ -67,7 +67,6 @@ def main(config_file):
         )

         path = os.path.join(save_dir, folder_name)
-
         generate_cts_temporal_data(
             path=path,
             num_graphs=N_G,
diff --git a/src/utils/data_gen/process_dream3.py b/src/utils/data_gen/process_dream3.py
index 0800f4c..68aecd6 100644
--- a/src/utils/data_gen/process_dream3.py
+++ b/src/utils/data_gen/process_dream3.py
@@ -1,11 +1,11 @@
-import scipy.io as sio
 import argparse
 import os
+
 import numpy as np
-import math
 import torch
 import pandas as pd

+
 def process_ts(ts, timepoints, N_subjects):
     N_nodes = ts.shape[1]
     X = np.zeros((N_subjects, timepoints, N_nodes))
@@ -15,6 +15,7 @@ def process_ts(ts, timepoints, N_subjects):

     return X

+
 def process_adj_matrix(net, size):
     A = np.zeros((size, size))
     for a, b, c in net.values:
@@ -23,6 +24,7 @@ def process_adj_matrix(net, size):
         A[src, dest] = 1
     return A

+
 def split_by_trajectory(X, A, T=21):
     time_len = X.shape[0]
     N = X.shape[1]
@@ -35,63 +37,84 @@ def split_by_trajectory(X, A, T=21):

     return data, adj_matrix

+
 def process_dream3(args):
     for size in [10, 50, 100]:
-        X1 = torch.load(os.path.join(args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli1.pt'))['TsData'].numpy()
-        A1 = pd.read_table(os.path.join(args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli1.tsv'), header=None)
+        X1 = torch.load(os.path.join(
+            args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli1.pt'))['TsData'].numpy()
+        A1 = pd.read_table(os.path.join(
+            args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli1.tsv'), header=None)
         A1 = process_adj_matrix(A1, size)
         X1, A1 = split_by_trajectory(X1, A1)

-        X2 = torch.load(os.path.join(args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli2.pt'))['TsData'].numpy()
-        A2 = pd.read_table(os.path.join(args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli2.tsv'), header=None)
+        X2 = torch.load(os.path.join(
+            args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli2.pt'))['TsData'].numpy()
+        A2 = pd.read_table(os.path.join(
+            args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli2.tsv'), header=None)
         A2 = process_adj_matrix(A2, size)
         X2, A2 = split_by_trajectory(X2, A2)
-
+
         if not os.path.exists(os.path.join(args.save_dir, f'ecoli_{size}', 'grouped_by_matrix')):
-            os.makedirs(os.path.join(args.save_dir, f'ecoli_{size}', 'grouped_by_matrix'))
-
-        np.savez(os.path.join(args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_1.npz'), X=X1, adj_matrix=A1)
-        np.savez(os.path.join(args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_2.npz'), X=X2, adj_matrix=A2)
-
+            os.makedirs(os.path.join(args.save_dir,
+                                     f'ecoli_{size}', 'grouped_by_matrix'))
+
+        np.savez(os.path.join(
+            args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_1.npz'), X=X1, adj_matrix=A1)
+        np.savez(os.path.join(
+            args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_2.npz'), X=X2, adj_matrix=A2)
+
         X = np.concatenate((X1, X2), axis=0)
         A = np.concatenate((A1, A2), axis=0)
-        np.savez(os.path.join(args.save_dir, f'ecoli_{size}', 'ecoli.npz'), X=X, adj_matrix=A)
+        np.savez(os.path.join(args.save_dir,
+                              f'ecoli_{size}', 'ecoli.npz'), X=X, adj_matrix=A)

-        X11 = torch.load(os.path.join(args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast1.pt'))['TsData'].numpy()
-        A11 = pd.read_table(os.path.join(args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast1.tsv'), header=None)
+        X11 = torch.load(os.path.join(
+            args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast1.pt'))['TsData'].numpy()
+        A11 = pd.read_table(os.path.join(
+            args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast1.tsv'), header=None)
         A11 = process_adj_matrix(A11, size)
         X11, A11 = split_by_trajectory(X11, A11)

-        X21 = torch.load(os.path.join(args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast2.pt'))['TsData'].numpy()
-        A21 = pd.read_table(os.path.join(args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast2.tsv'), header=None)
+        X21 = torch.load(os.path.join(
+            args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast2.pt'))['TsData'].numpy()
+        A21 = pd.read_table(os.path.join(
+            args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast2.tsv'), header=None)
         A21 = process_adj_matrix(A21, size)
         X21, A21 = split_by_trajectory(X21, A21)

-        X31 = torch.load(os.path.join(args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast3.pt'))['TsData'].numpy()
-        A31 = pd.read_table(os.path.join(args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast3.tsv'), header=None)
+        X31 = torch.load(os.path.join(
+            args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast3.pt'))['TsData'].numpy()
+        A31 = pd.read_table(os.path.join(
+            args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast3.tsv'), header=None)
         A31 = process_adj_matrix(A31, size)
         X31, A31 = split_by_trajectory(X31, A31)

         if not os.path.exists(os.path.join(args.save_dir, f'yeast_{size}', 'grouped_by_matrix')):
-            os.makedirs(os.path.join(args.save_dir, f'yeast_{size}', 'grouped_by_matrix'))
-
-        np.savez(os.path.join(args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_1.npz'), X=X11, adj_matrix=A11)
-        np.savez(os.path.join(args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_2.npz'), X=X21, adj_matrix=A21)
-        np.savez(os.path.join(args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_3.npz'), X=X31, adj_matrix=A31)
+            os.makedirs(os.path.join(args.save_dir,
+                                     f'yeast_{size}', 'grouped_by_matrix'))
+
+        np.savez(os.path.join(
+            args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_1.npz'), X=X11, adj_matrix=A11)
+        np.savez(os.path.join(
+            args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_2.npz'), X=X21, adj_matrix=A21)
+        np.savez(os.path.join(
+            args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_3.npz'), X=X31, adj_matrix=A31)

         X = np.concatenate((X11, X21, X31), axis=0)
         A = np.concatenate((A11, A21, A31), axis=0)
         print(X.shape)
         print(A.shape)
-        np.savez(os.path.join(args.save_dir, f'yeast_{size}', 'yeast.npz'), X=X, adj_matrix=A)
-
+        np.savez(os.path.join(args.save_dir,
+                              f'yeast_{size}', 'yeast.npz'), X=X, adj_matrix=A)
+
         # save combined
         X = np.concatenate((X1, X2, X11, X21, X31), axis=0)
         A = np.concatenate((A1, A2, A11, A21, A31), axis=0)
         print(X.shape)
         print(A.shape)
-        np.savez(os.path.join(args.save_dir, f'combined_{size}.npz'), X=X, adj_matrix=A)
-
+        np.savez(os.path.join(args.save_dir,
+                              f'combined_{size}.npz'), X=X, adj_matrix=A)
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/src/utils/data_gen/process_netsim.py b/src/utils/data_gen/process_netsim.py
index 23c7a27..dd94b5a 100644
--- a/src/utils/data_gen/process_netsim.py
+++ b/src/utils/data_gen/process_netsim.py
@@ -1,10 +1,12 @@
-import scipy.io as sio
-import argparse
 import os
-import numpy as np
 import math
 import random

+import scipy.io as sio
+import argparse
+import numpy as np
+
+
 def process_ts(ts, timepoints, N_subjects):
     N_nodes = ts.shape[1]
     X = np.zeros((N_subjects, timepoints, N_nodes))
@@ -14,9 +16,11 @@ def process_ts(ts, timepoints, N_subjects):

     return X

+
 def process_adj_matrix(net):
     return (np.abs(np.swapaxes(net, 1, 2)) > 0).astype(int)
-
+
+
 def process_netsim(args):
     seed = 0
     simulations = range(1, 29)
@@ -28,7 +32,7 @@
     for i in simulations:
         mat = sio.loadmat(os.path.join(args.dataset_dir, f'sim{i}.mat'))
         timepoints = mat['Ntimepoints'][0, 0]
-        N_subjects = mat['Nsubjects'][0, 0] 
+        N_subjects = mat['Nsubjects'][0, 0]
         ts = mat['ts']
         net = mat['net']
         N_nodes = ts.shape[1]
@@ -39,44 +43,50 @@ def process_netsim(args):
             N_t_dict[(N_nodes, timepoints)] = {}
             N_t_dict[(N_nodes, timepoints)]['X'] = []
             N_t_dict[(N_nodes, timepoints)]['adj_matrix'] = []
-
+
         N_t_dict[(N_nodes, timepoints)]['X'].append(X)
         N_t_dict[(N_nodes, timepoints)]['adj_matrix'].append(adj_matrix)
-
+
         for j in range(X.shape[0]):
             if (adj_matrix[j].tobytes(), X.shape[1]) not in X_dict_by_matrix:
                 X_dict_by_matrix[(adj_matrix[j].tobytes(), X.shape[1])] = []
-            X_dict_by_matrix[(adj_matrix[j].tobytes(), X.shape[1])].append(X[j])
+            X_dict_by_matrix[(adj_matrix[j].tobytes(),
+                              X.shape[1])].append(X[j])
         if N_nodes == 15 and timepoints == 200:
             print("SIMULATION", i)

     for N_nodes, timepoints in N_t_dict:
         X = np.concatenate(N_t_dict[(N_nodes, timepoints)]['X'], axis=0)
-        adj_matrix = np.concatenate(N_t_dict[(N_nodes, timepoints)]['adj_matrix'], axis=0)
-        np.savez(os.path.join(args.save_dir, f'netsim_{N_nodes}_{timepoints}.npz'), X=X, adj_matrix=adj_matrix)
+        adj_matrix = np.concatenate(
+            N_t_dict[(N_nodes, timepoints)]['adj_matrix'], axis=0)
+        np.savez(os.path.join(
+            args.save_dir, f'netsim_{N_nodes}_{timepoints}.npz'), X=X, adj_matrix=adj_matrix)

     counts = {}
     for key, T in X_dict_by_matrix:
         N = int(math.sqrt(len(np.frombuffer(key, dtype=int))))
         if (N, T) not in counts:
             counts[(N, T)] = 0
-
+
         adj_matrix = np.array(np.frombuffer(key, dtype=int)).reshape(N, N)
         if not os.path.exists(os.path.join(args.save_dir, 'grouped_by_matrix', f'{N}_{T}')):
-            os.makedirs(os.path.join(args.save_dir, 'grouped_by_matrix', f'{N}_{T}'))
+            os.makedirs(os.path.join(args.save_dir,
+                                     'grouped_by_matrix', f'{N}_{T}'))
         X = np.array(X_dict_by_matrix[(key, T)])
-        np.savez(os.path.join(args.save_dir, 'grouped_by_matrix', f'{N}_{T}', f'netsim_{counts[(N, T)]}.npz'), X=X, adj_matrix=adj_matrix)
+        np.savez(os.path.join(args.save_dir, 'grouped_by_matrix',
+                              f'{N}_{T}', f'netsim_{counts[(N, T)]}.npz'), X=X, adj_matrix=adj_matrix)
         counts[(N, T)] += 1

     # add the permutations
     for N_nodes in [15, 50]:
         timepoints = 200
         num_graphs = 3
-        permutation_pool = [np.random.permutation(np.arange(N_nodes)) for i in range(num_graphs)]
+        permutation_pool = [np.random.permutation(
+            np.arange(N_nodes)) for i in range(num_graphs)]
         X = []
         adj_matrix = []
-
+
         for i in range(len(N_t_dict[(N_nodes, timepoints)]['X'][0])):
             I = random.choice(permutation_pool)
             x = N_t_dict[(N_nodes, timepoints)]['X'][0][i]
@@ -85,14 +95,15 @@ def process_netsim(args):

             G = N_t_dict[(N_nodes, timepoints)]['adj_matrix'][0][i]
             G = G[I][:, I]
-
+
             X.append(x)
             adj_matrix.append(G)

         X = np.array(X)
         adj_matrix = np.array(adj_matrix)
-        np.savez(os.path.join(args.save_dir, f'netsim_{N_nodes}_{timepoints}_permuted.npz'), X=X, adj_matrix=adj_matrix)
-
+        np.savez(os.path.join(
+            args.save_dir, f'netsim_{N_nodes}_{timepoints}_permuted.npz'), X=X, adj_matrix=adj_matrix)
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/src/utils/data_gen/splines.py b/src/utils/data_gen/splines.py
index 3bfb26e..cf51e03 100644
--- a/src/utils/data_gen/splines.py
+++ b/src/utils/data_gen/splines.py
@@ -140,29 +140,35 @@ def RQS(
         theta_one_minus_theta = root * (1 - root)

         denominator = input_delta + (
-            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
+            (input_derivatives + input_derivatives_plus_one -
+             2 * input_delta) * theta_one_minus_theta
         )
         derivative_numerator = input_delta.pow(2) * (
             input_derivatives_plus_one * root.pow(2)
             + 2 * input_delta * theta_one_minus_theta
             + input_derivatives * (1 - root).pow(2)
         )
-        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+        logabsdet = torch.log(derivative_numerator) - \
+            2 * torch.log(denominator)
         return outputs, -logabsdet
-    else:
-        theta = (inputs - input_cumwidths) / input_bin_widths
-        theta_one_minus_theta = theta * (1 - theta)
-
-        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
-        denominator = input_delta + (
-            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
-        )
-        outputs = input_cumheights + numerator / denominator
-        derivative_numerator = input_delta.pow(2) * (
-            input_derivatives_plus_one * theta.pow(2)
-            + 2 * input_delta * theta_one_minus_theta
-            + input_derivatives * (1 - theta).pow(2)
-        )
-        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
-        return outputs, logabsdet
+    # else
+    theta = (inputs - input_cumwidths) / input_bin_widths
+    theta_one_minus_theta = theta * (1 - theta)
+
+    numerator = input_heights * \
+        (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
+    denominator = input_delta + (
+        (input_derivatives + input_derivatives_plus_one -
+         2 * input_delta) * theta_one_minus_theta
+    )
+    outputs = input_cumheights + numerator / denominator
+
+    derivative_numerator = input_delta.pow(2) * (
+        input_derivatives_plus_one * theta.pow(2)
+        + 2 * input_delta * theta_one_minus_theta
+        + input_derivatives * (1 - theta).pow(2)
+    )
+    logabsdet = torch.log(derivative_numerator) - \
+        2 * torch.log(denominator)
+    return outputs, logabsdet
diff --git a/src/utils/data_utils/data_format_utils.py b/src/utils/data_utils/data_format_utils.py
index 907f191..2d5b2ff 100644
--- a/src/utils/data_utils/data_format_utils.py
+++ b/src/utils/data_utils/data_format_utils.py
@@ -1,6 +1,7 @@
 import torch
 import numpy as np

+
 def convert_data_to_timelagged(X, lag):
     """
     Converts data with shape (n_samples, timesteps, num_nodes, data_dim) to two tensors,
@@ -20,9 +21,11 @@ def convert_data_to_timelagged(X, lag):
         X_indices[i*n_samples:(i+1)*n_samples] = np.arange(n_samples)
     return X_history, X_input, X_indices

+
 def get_adj_matrix_id(A):
     return np.unique(A, axis=(0), return_inverse=True)
-
+
+
 def convert_adj_to_timelagged(A, lag, n_fragments, aggregated_graph=False, return_indices=True):
     """
     Converts adjacency matrix with shape (n_samples, (lag+1), num_nodes, num_nodes) to shape
@@ -41,7 +44,7 @@ def convert_adj_to_timelagged(A, lag, n_fragments, aggregated_graph=False, retur
         assert False, "invalid adjacency matrix"

     n_fragments_per_sample = n_fragments // n_samples
-    unique_matrices, matrix_indices = get_adj_matrix_id(A)
+    _, matrix_indices = get_adj_matrix_id(A)

     for i in range(n_fragments_per_sample):
         Ap[i*n_samples:(i+1)*n_samples] = A
@@ -49,8 +52,7 @@ def convert_adj_to_timelagged(A, lag, n_fragments, aggregated_graph=False, retur

     if return_indices:
         return Ap, A_indices
-    else:
-        return Ap
+    return Ap


 def to_time_aggregated_graph_np(graph):
@@ -58,34 +60,39 @@
     # graph of shape [batch, num_nodes, num_nodes]
     return (np.sum(graph, axis=1) > 0).astype(int)

+
 def to_time_aggregated_scores_np(graph):
     return np.max(graph, axis=1)

+
 def to_time_aggregated_graph_torch(graph):
     # convert graph of shape [batch, lag+1, num_nodes, num_nodes] to aggregated
     # graph of shape [batch, num_nodes, num_nodes]
     return (torch.sum(graph, dim=1) > 0).long()

+
 def to_time_aggregated_scores_torch(graph):
     # convert edge probability matrix of shape [batch, lag+1, num_nodes, num_nodes] to aggregated
     # matrix of shape [batch, num_nodes, num_nodes]
     max_val, _ = torch.max(graph, dim=1)
     return max_val

+
 def zero_out_diag_np(G):
-
-    if len(G.shape) == 3:
-        N = G.shape[1]
-        I = np.arange(N)
-        G[:, I, I] = 0
-
-    elif len(G.shape) == 2:
-        N = G.shape[0]
-        I = np.arange(N)
-        G[I, I] = 0
-
-    return G
-
+
+    if len(G.shape) == 3:
+        N = G.shape[1]
+        I = np.arange(N)
+        G[:, I, I] = 0
+
+    elif len(G.shape) == 2:
+        N = G.shape[0]
+        I = np.arange(N)
+        G[I, I] = 0
+
+    return G
+
+
 def zero_out_diag_torch(G):

     if len(G.shape) == 3:
@@ -97,5 +104,5 @@ def zero_out_diag_torch(G):
         N = G.shape[0]
         I = torch.arange(N)
         G[I, I] = 0
-
-    return G
\ No newline at end of file
+
+    return G
diff --git a/src/utils/data_utils/dataloading_utils.py b/src/utils/data_utils/dataloading_utils.py
index 7dea80a..d9a3c64 100644
--- a/src/utils/data_utils/dataloading_utils.py
+++ b/src/utils/data_utils/dataloading_utils.py
@@ -1,5 +1,6 @@
-import numpy as np
 import os
+import numpy as np
+

 def get_dataset_path(dataset):
     if 'netsim' in dataset:
@@ -8,18 +9,19 @@ def get_dataset_path(dataset):
         dataset_path = 'dream3'
     elif 'snp100' in dataset:
         dataset_path = 'snp100'
-    elif dataset == 'lorenz96' or dataset == 'finance' or dataset == 'fluxnet':
+    elif dataset in ['lorenz96', 'finance', 'fluxnet']:
         dataset_path = dataset
     else:
         dataset_path = 'synthetic'
-
+
     return dataset_path

+
 def create_save_name(dataset, cfg):
     if dataset == 'lorenz96':
         return f'lorenz96_N={cfg.num_nodes}_T={cfg.timesteps}_num_graphs={cfg.num_graphs}'
-    else:
-        return dataset
+    return dataset
+

 def load_synthetic_from_folder(dataset_dir, dataset_name):
     X = np.load(os.path.join(dataset_dir, dataset_name, 'X.npy'))
@@ -28,6 +30,7 @@ def load_synthetic_from_folder(dataset_dir, dataset_name):

     return X, adj_matrix

+
 def load_netsim(dataset_dir, dataset_file):
     # load the files
     data = np.load(os.path.join(dataset_dir, dataset_file + '.npz'))
@@ -36,6 +39,7 @@ def load_netsim(dataset_dir, dataset_file):
     # adj_matrix = np.transpose(adj_matrix, (0, 2, 1))
     return X, adj_matrix

+
 def load_dream3_combined(dataset_dir, size):
     data = np.load(os.path.join(dataset_dir, f'combined_{size}.npz'))
     X = data['X']
@@ -50,12 +54,13 @@ def load_snp100(dataset, dataset_dir):
     # get the sector
     sector = dataset.split('_')[1]
     X = np.load(os.path.join(dataset_dir, f'X_{sector}.npy'))
-
+
     D = X.shape[2]
-    # we do not have the true adjacency matrix 
+    # we do not have the true adjacency matrix
     adj_matrix = np.zeros((X.shape[0], D, D))
     return X, adj_matrix

+
 def load_data(dataset, dataset_dir, config):
     if 'netsim' in dataset:
         X, adj_matrix = load_netsim(
@@ -65,13 +70,14 @@ def load_data(dataset, dataset_dir, config):
         # read lag from config file
         lag = int(config['lag'])
         data_dim = 1
-        X = np.expand_dims(X, axis=-1) 
+        X = np.expand_dims(X, axis=-1)
     elif dataset == 'dream3':
         dream3_size = int(config['dream3_size'])
-        X, adj_matrix = load_dream3_combined(dataset_dir=dataset_dir, size=dream3_size)
+        X, adj_matrix = load_dream3_combined(
+            dataset_dir=dataset_dir, size=dream3_size)
         lag = int(config['lag'])
         data_dim = 1
-        aggregated_graph=True
+        aggregated_graph = True
         X = np.expand_dims(X, axis=-1)
     elif 'snp100' in dataset:
         X, adj_matrix = load_snp100(dataset=dataset, dataset_dir=dataset_dir)
@@ -87,4 +93,4 @@ def load_data(dataset, dataset_dir, config):
         X = np.expand_dims(X, axis=-1)
         aggregated_graph = False
     print("Loaded data of shape:", X.shape)
-    return X, adj_matrix, aggregated_graph, lag, data_dim
\ No newline at end of file
+    return X, adj_matrix, aggregated_graph, lag, data_dim
diff --git a/src/utils/loss_utils.py b/src/utils/loss_utils.py
index 031efbf..e35d837 100644
--- a/src/utils/loss_utils.py
+++ b/src/utils/loss_utils.py
@@ -1,6 +1,5 @@
 import torch
-import numpy as np


 def temporal_graph_sparsity(G: torch.Tensor):
     """
@@ -13,9 +12,11 @@

     return torch.sum(torch.square(G))

+
 def l1_sparsity(G: torch.Tensor):
     return torch.sum(torch.abs(G))

+
 def dag_penalty_notears(G: torch.Tensor):
     """
     Implements the DAGness penalty from
@@ -29,17 +30,20 @@
     if len(G.shape) == 2:
         trace_term = torch.trace(torch.matrix_exp(G))
-        return (trace_term - num_nodes)
+        return trace_term - num_nodes
     elif len(G.shape) == 3:
         trace_term = torch.einsum("ijj->i", torch.matrix_exp(G))
         return torch.sum(trace_term - num_nodes)
+    assert False, "DAG Penalty received illegal shape"
+

 def dag_penalty_notears_sq(W: torch.Tensor):
     num_nodes = W.shape[-1]
     if len(W.shape) == 2:
         trace_term = torch.trace(torch.matrix_exp(W * W))
-        return (trace_term - num_nodes)
+        return trace_term - num_nodes
     elif len(W.shape) == 3:
         trace_term = torch.einsum("ijj->i", torch.matrix_exp(W * W))
         return torch.sum(trace_term) - W.shape[0] * num_nodes
+    assert False, "DAG Penalty received illegal shape"
diff --git a/src/utils/metrics_utils.py b/src/utils/metrics_utils.py
index 529ec99..a891d00 100644
--- a/src/utils/metrics_utils.py
+++ b/src/utils/metrics_utils.py
@@ -7,11 +7,13 @@
 import numpy as np
 import torch

+
 def get_off_diagonal(A):
     # assumes A.shape: (batch, x, y)
     M = np.invert(np.eye(A.shape[1], dtype=bool))
     return A[:, M]

+
 def adjacency_f1(adj_matrix, predictions):
     # adj_matrix: (b, l, d, d) or (b, d, d)
     # predictions: (b, l, d, d) or (b, d, d)
@@ -23,13 +25,15 @@ def adjacency_f1(adj_matrix, predictions):
     adj_upper = adj_matrix[..., U[0], U[1]]
     adj_diag = np.diagonal(adj_matrix, axis1=-2, axis2=-1).flatten()
-    adj = np.concatenate((adj_diag, np.logical_or(adj_lower, adj_upper).flatten().astype(int)))
-
+    adj = np.concatenate((adj_diag, np.logical_or(
+        adj_lower, adj_upper).flatten().astype(int)))
+
     pred_diag = np.diagonal(predictions, axis1=-2, axis2=-1).flatten()
     pred_lower = predictions[..., L[0], L[1]]
     pred_upper = predictions[..., U[0], U[1]]
-    pred = np.concatenate((pred_diag, np.logical_or(pred_lower, pred_upper).flatten().astype(int)))
-
+    pred = np.concatenate((pred_diag, np.logical_or(
+        pred_lower, pred_upper).flatten().astype(int)))
+
     return f1_score(adj, pred)
@@ -62,52 +66,53 @@ def compute_shd(adj_matrix, preds, aggregated_graph=False):
             else:
                 shd_lag += shd
         return shd_score/adj_matrix.shape[0], shd_inst/adj_matrix.shape[0], shd_lag/adj_matrix.shape[0]
-    else:
-        for i in range(adj_matrix.shape[0]):
-            adj_sub_matrix = adj_matrix[i]
-            preds_sub_matrix = preds[i]
-            shd_score += SHD(adj_sub_matrix, preds_sub_matrix)
-            # print(SHD(adj_sub_matrix, preds_sub_matrix))
-        return shd_score/adj_matrix.shape[0]
+    for i in range(adj_matrix.shape[0]):
+        adj_sub_matrix = adj_matrix[i]
+        preds_sub_matrix = preds[i]
+        shd_score += SHD(adj_sub_matrix, preds_sub_matrix)
+        # print(SHD(adj_sub_matrix, preds_sub_matrix))
+    return shd_score/adj_matrix.shape[0]
+

 def calculate_expected_shd(scores, adj_matrix, aggregated_graph=False, n_trials=100):
     totals_shd = 0
-    for i in range(n_trials):
+    for _ in range(n_trials):
         draw = np.random.binomial(1, scores)
         if aggregated_graph:
-            shd = compute_shd(adj_matrix, draw, aggregated_graph=aggregated_graph)
+            shd = compute_shd(adj_matrix, draw,
+                              aggregated_graph=aggregated_graph)
         else:
-            shd, shd_lag, shd_inst = compute_shd(adj_matrix, draw, aggregated_graph=aggregated_graph)
+            shd, _, _ = compute_shd(
+                adj_matrix, draw, aggregated_graph=aggregated_graph)
         totals_shd += shd
-
+
     return totals_shd/n_trials


-def evaluate_results(scores,
-                     adj_matrix,
-                     predictions,
-                     aggregated_graph=False,
+def evaluate_results(scores,
+                     adj_matrix,
+                     predictions,
+                     aggregated_graph=False,
                      true_cluster_indices=None,
                      pred_cluster_indices=None):
-
-    num_samples = scores.shape[0]
+    assert adj_matrix.shape == predictions.shape, "Dimension of adj_matrix should match the predictions"

     abs_scores = np.abs(scores).flatten()
     preds = np.abs(np.round(predictions))
     truth = adj_matrix.flatten()
-    print(adj_matrix.shape)
+
     # calculate shd
-    m = adj_matrix.shape[0]
     if aggregated_graph:
         shd_score = compute_shd(adj_matrix, preds, aggregated_graph)
     else:
         shd_score, shd_inst, shd_lag = compute_shd(
             adj_matrix, preds, aggregated_graph)
-        f1_inst = f1_score(get_off_diagonal(adj_matrix[:, 0]).flatten(), get_off_diagonal(predictions[:, 0]).flatten())
+        f1_inst = f1_score(get_off_diagonal(adj_matrix[:, 0]).flatten(
+        ), get_off_diagonal(predictions[:, 0]).flatten())
         f1_lag = f1_score(adj_matrix[:, 1:].flatten(), preds[:, 1:].flatten())
-
+
     f1 = f1_score(truth, preds.flatten())
     adj_f1 = adjacency_f1(adj_matrix, predictions)
@@ -125,7 +130,7 @@ def evaluate_results(scores,
         rocauc = roc_auc_score(truth, abs_scores)
     except ValueError:
         rocauc = 0.5
-
+
     tnr = zero_edge_accuracy
     tpr = one_edge_accuracy

@@ -135,7 +140,7 @@ def evaluate_results(scores,
     print("Precision score:", precision)
     print("Recall score:", recall)
     print("ROC AUC score:", rocauc)
-
+
     print("Accuracy on '0' edges", tnr)
     print("Accuracy on '1' edges", tpr)
     print("Structural Hamming Distance:", shd_score)
@@ -144,7 +149,8 @@ def evaluate_results(scores,
         print("Structural Hamming Distance (lag):", shd_lag)
         print("Orientation F1 inst", f1_inst)
         print("Orientation F1 lag", f1_lag)
-    eshd = calculate_expected_shd(np.abs(scores/(np.max(scores)+1e-4)), adj_matrix, aggregated_graph)
+    eshd = calculate_expected_shd(
+        np.abs(scores/(np.max(scores)+1e-4)), adj_matrix, aggregated_graph)
     print("Expected SHD:", eshd)
     # also return a dictionary of metrics
     metrics = {
@@ -168,19 +174,21 @@ def evaluate_results(scores,
         metrics['f1_lag'] = f1_lag

     if pred_cluster_indices is not None and true_cluster_indices is not None:
-        metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices, 
+        metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices,
                                                   pred_idx=pred_cluster_indices)
     else:
-        _, true_cluster_indices = np.unique(adj_matrix, return_inverse=True, axis=0)
-        _, pred_cluster_indices = np.unique(predictions, return_inverse=True, axis=0)
-        metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices, 
+        _, true_cluster_indices = np.unique(
+            adj_matrix, return_inverse=True, axis=0)
+        _, pred_cluster_indices = np.unique(
+            predictions, return_inverse=True, axis=0)
+        metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices,
                                                   pred_idx=pred_cluster_indices)
-
+
     return metrics


-def mape_loss(X_true, X_pred):
-    return torch.mean(torch.abs((X_true - X_pred) / X_true))*100
+def mape_loss(X_true, x_pred):
+    return torch.mean(torch.abs((X_true - x_pred) / X_true))*100


 def cluster_accuracy(true_idx, pred_idx):
@@ -193,4 +201,4 @@ def cluster_accuracy(true_idx, pred_idx):
     row_ind, col_ind = linear_sum_assignment(cm, maximize=True)

     # get the maximum matching
-    return cm[row_ind, col_ind].sum()/cm.sum()
\ No newline at end of file
+    return cm[row_ind, col_ind].sum()/cm.sum()
diff --git a/src/utils/torch_utils.py b/src/utils/torch_utils.py
index 69bc8aa..c005adf 100644
--- a/src/utils/torch_utils.py
+++ b/src/utils/torch_utils.py
@@ -2,10 +2,11 @@
 Borrowed from github.com/microsoft/causica
 """
-from typing import List, Optional, Tuple, Type, Union
+from typing import List, Optional, Type
 from torch.nn import Dropout, LayerNorm, Linear, Module, Sequential
 import torch

+
 class resBlock(Module):
     """
     Wraps an nn.Module, adding a skip connection to it.
@@ -40,7 +41,8 @@ def generate_fully_connected(
     Args:
         input_dim: Int. Size of input to network.
         output_dim: Int. Size of output of network.
-        hidden_dims: List of int. Sizes of internal hidden layers. i.e. [a, b] is three linear layers with shapes (input_dim, a), (a, b), (b, output_dim)
+        hidden_dims: List of int. Sizes of internal hidden layers. i.e.
+            [a, b] is three linear layers with shapes (input_dim, a), (a, b), (b, output_dim)
         non_linearity: Non linear activation function used between Linear layers.
         activation: Final layer activation to use.
         device: torch device to load weights to.
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 75902e8..16d92fc 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -1,9 +1,7 @@
-import numpy as np
 import os
-from matplotlib.animation import FuncAnimation, PillowWriter
-import matplotlib.pyplot as plt
-import seaborn as sns
 import csv
+import numpy as np
+

 def standard_scaling(X, across_samples=False):
     # expected X of shape (n_samples, timesteps, num_nodes, data_dim) or (n_samples, timesteps, num_nodes)
@@ -31,11 +29,11 @@ def min_max_scaling(X, across_samples=False):
         mins = np.amin(X, axis=(1))[:, np.newaxis]
         maxs = np.amax(X, axis=(1))[:, np.newaxis]

-    eps = 1e-6
     Y = (X-mins) / (maxs - mins) * 2 - 1

     return Y

+
 def write_results_to_disk(dataset, metrics):
     # write results to file
     results_dir = os.path.join('results', dataset)
@@ -45,7 +43,7 @@ def write_results_to_disk(dataset, metrics):
     if not os.path.exists(results_dir):
         os.makedirs(results_dir)

-    with open(results_file, 'a') as csvfile:
+    with open(results_file, 'a', encoding="utf-8") as csvfile:
         writer = csv.DictWriter(csvfile, fieldnames=list(metrics.keys()))
         if not file_exists:
             writer.writeheader()