From dbad7e56c7125593d726cc4158ea0440cd4f96ff Mon Sep 17 00:00:00 2001 From: alitinet Date: Sun, 21 Jul 2024 19:03:43 +0200 Subject: [PATCH 1/2] added docs --- src/multimil/data/_preprocessing.py | 13 +- src/multimil/dataloaders/_ann_dataloader.py | 47 ++-- src/multimil/dataloaders/_data_splitting.py | 19 +- src/multimil/distributions/_mmd.py | 40 ++-- src/multimil/model/_mil.py | 166 +++++++------ src/multimil/model/_multivae.py | 174 ++++++++------ src/multimil/model/_multivae_mil.py | 246 ++++++++++++-------- src/multimil/module/_mil_torch.py | 106 ++++++++- src/multimil/module/_multivae_mil_torch.py | 146 +++++++++++- src/multimil/module/_multivae_torch.py | 209 ++++++++++------- src/multimil/nn/_base_components.py | 130 +++++++++-- src/multimil/utils/_utils.py | 163 ++++++++++++- 12 files changed, 1047 insertions(+), 412 deletions(-) diff --git a/src/multimil/data/_preprocessing.py b/src/multimil/data/_preprocessing.py index 04d8f98..91b6b82 100644 --- a/src/multimil/data/_preprocessing.py +++ b/src/multimil/data/_preprocessing.py @@ -8,7 +8,7 @@ def organize_multiome_anndatas( adatas: list[list[ad.AnnData | None]], layers: list[list[str | None]] | None = None, -): +) -> ad.AnnData: """Concatenate all the input anndata objects. These anndata objects should already have been preprocessed so that all single-modality @@ -16,11 +16,16 @@ def organize_multiome_anndatas( `.var`) should match between the objects for vertical integration and cell names (index of `.obs`) should match between the objects for horizontal integration. - :param adatas: - List of Lists with AnnData objects or None where each sublist corresponds to a modality - :param layers: + Parameters + ---------- + adatas + List of Lists with AnnData objects or None where each sublist corresponds to a modality. + layers List of Lists of the same lengths as `adatas` specifying which `.layer` to use for each AnnData. Default is None which means using `.X`. + Returns + ------- + Concatenated AnnData object across modalities and datasets. """ # TODO: add checks for layers # TODO: add check that len of modalities is the same as len of losses, etc diff --git a/src/multimil/dataloaders/_ann_dataloader.py b/src/multimil/dataloaders/_ann_dataloader.py index ce988c1..05b4563 100644 --- a/src/multimil/dataloaders/_ann_dataloader.py +++ b/src/multimil/dataloaders/_ann_dataloader.py @@ -14,16 +14,18 @@ class StratifiedSampler(torch.utils.data.sampler.Sampler): """Custom stratified sampler class which enables sampling the same number of observation from each group in each mini-batch. - :param indices: - list of indices to sample from - :param batch_size: - batch size of each iteration - :param shuffle: - if ``True``, shuffles indices before sampling - :param drop_last: - if int, drops the last batch if its length is less than drop_last. - if drop_last == True, drops last non-full batch. - if drop_last == False, iterate over all batches. + Parameters + ---------- + indices + List of indices to sample from. + batch_size + Batch size of each iteration. + shuffle + If ``True``, shuffles indices before sampling. + drop_last + If int, drops the last batch if its length is less than drop_last. + If drop_last == True, drops last non-full batch. + If drop_last == False, iterate over all batches. """ def __init__( @@ -139,23 +141,24 @@ def __len__(self): # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_ann_dataloader.py#L89 # accessed on 5 November 2022 class GroupAnnDataLoader(DataLoader): - """ - DataLoader for loading tensors from AnnData objects. + """DataLoader for loading tensors from AnnData objects. - :param adata_manager: + Parameters + ---------- + adata_manager :class:`~scvi.data.AnnDataManager` object with a registered AnnData object. - :param shuffle: - Whether the data should be shuffled - :param indices: - The indices of the observations in the adata to load - :param batch_size: - minibatch size to load each iteration - :param data_and_attributes: + shuffle + Whether the data should be shuffled. + indices + The indices of the observations in the adata to load. + batch_size + Minibatch size to load each iteration. + data_and_attributes Dictionary with keys representing keys in data registry (`adata.uns["_scvi"]`) and value equal to desired numpy loading type (later made into torch tensor). If `None`, defaults to all registered data. - :param data_loader_kwargs: - Keyword arguments for :class:`~torch.utils.data.DataLoader` + data_loader_kwargs + Keyword arguments for :class:`~torch.utils.data.DataLoader`. """ def __init__( diff --git a/src/multimil/dataloaders/_data_splitting.py b/src/multimil/dataloaders/_data_splitting.py index 0c9e2b4..f2f6822 100644 --- a/src/multimil/dataloaders/_data_splitting.py +++ b/src/multimil/dataloaders/_data_splitting.py @@ -8,21 +8,22 @@ # https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_data_splitting.py#L56 # accessed on 5 November 2022 class GroupDataSplitter(DataSplitter): - """ - Creates data loaders ``train_set``, ``validation_set``, ``test_set``. + """Creates data loaders ``train_set``, ``validation_set``, ``test_set``. If ``train_size + validation_set < 1`` then ``test_set`` is non-empty. - :param adata_manager: + Parameters + ---------- + adata_manager :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``. - :param train_size: - float, or None (default is 0.9) - :param validation_size: - float, or None (default is None) - :param use_gpu: + train_size + Proportion of cells to use as the train set. Float, or None (default is 0.9). + validation_size + Proportion of cell to use as the valisation set. Float, or None (default is None). If None, is set to 1 - ``train_size``. + use_gpu Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). - :param kwargs: + kwargs Keyword args for data loader. Data loader class is :class:`~mtg.dataloaders.GroupAnnDataLoader`. """ diff --git a/src/multimil/distributions/_mmd.py b/src/multimil/distributions/_mmd.py index f879c98..acfa2f8 100644 --- a/src/multimil/distributions/_mmd.py +++ b/src/multimil/distributions/_mmd.py @@ -4,10 +4,12 @@ class MMD(torch.nn.Module): """Maximum mean discrepancy. - :param kernel_type: + Parameters + ---------- + kernel_type Indicates if to use Gaussian kernel. One of * ``'gaussian'`` - use Gaussian kernel - * ``'not gaussian'`` - do not use Gaussian kernel + * ``'not gaussian'`` - do not use Gaussian kernel. """ def __init__(self, kernel_type="gaussian"): @@ -23,12 +25,18 @@ def gaussian_kernel( ) -> torch.Tensor: """Apply Guassian kernel. - :param x: - Tensor from the first distribution - :param y: - Tensor from the second distribution - :param gamma: - List of gamma parameters + Parameters + ---------- + x + Tensor from the first distribution. + y + Tensor from the second distribution. + gamma + List of gamma parameters. + + Returns + ------- + Gaussian kernel between ``x`` and ``y``. """ if gamma is None: gamma = [ @@ -71,12 +79,16 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: Availability: https://github.com/theislab/scarches/blob/63a7c2b35a01e55fe7e1dd871add459a86cd27fb/scarches/models/trvae/losses.py Citation: Gretton, Arthur, et al. "A Kernel Two-Sample Test", 2012. - :param x: - Tensor with shape ``(batch_size, z_dim)`` - :param y: - Tensor with shape ``(batch_size, z_dim)`` - :returns: - MMD between ``x`` and ``y`` + Parameters + ---------- + x + Tensor with shape ``(batch_size, z_dim)``. + y + Tensor with shape ``(batch_size, z_dim)``. + + Returns + ------- + MMD between ``x`` and ``y``. """ # in case there is only one sample in a batch belonging to one of the groups, then skip the batch if len(x) == 1 or len(y) == 1: diff --git a/src/multimil/model/_mil.py b/src/multimil/model/_mil.py index b7e2fef..e58165a 100644 --- a/src/multimil/model/_mil.py +++ b/src/multimil/model/_mil.py @@ -31,6 +31,62 @@ class MILClassifier(BaseModelClass, ArchesMixin): + """MultiMIL MIL prediction model. + + Parameters + ---------- + adata + AnnData object containing embeddings and covariates. + sample_key + Key in `adata.obs` that corresponds to the sample covariate. + classification + List of keys in `adata.obs` that correspond to the classification covariates. + regression + List of keys in `adata.obs` that correspond to the regression covariates. + ordinal_regression + List of keys in `adata.obs` that correspond to the ordinal regression covariates. + sample_batch_size + Number of samples per bag, i.e. sample. Default is 128. + normalization + One of "layer" or "batch". Default is "layer". + z_dim + Dimensionality of the input latent space. Default is 16. + dropout + Dropout rate. Default is 0.2. + scoring + How to calculate attention scores. One of "gated_attn", "MLP". Default is "gated_attn". + attn_dim + Dimensionality of the hidden layer in attention calculation. Default is 16. + n_layers_cell_aggregator + Number of layers in the cell aggregator. Default is 1. + n_layers_classifier + Number of layers in the classifier. Default is 2. + n_layers_regressor + Number of layers in the regressor. Default is 2. + n_layers_mlp_attn + Number of layers in the MLP attention. Only used if `scoring` = "MLP". Default is 1. + n_hidden_cell_aggregator + Number of hidden units in the cell aggregator. Default is 128. + n_hidden_classifier + Number of hidden units in the classifier. Default is 128. + n_hidden_mlp_attn + Number of hidden units in the MLP attention. Default is 32. + n_hidden_regressor + Number of hidden units in the regressor. Default is 128. + class_loss_coef + Coefficient for the classification loss. Default is 1.0. + regression_loss_coef + Coefficient for the regression loss. Default is 1.0. + activation + Activation function. Default is 'leaky_relu'. + initialization + Initialization method for the weights. Default is None. + anneal_class_loss + Whether to anneal the classification loss. Default is False. + ignore_covariates + List of covariates to ignore. Needed for query-to-reference mapping. Default is None. + """ + def __init__( self, adata, @@ -59,59 +115,6 @@ def __init__( anneal_class_loss=False, ignore_covariates=None, ): - """MultiMIL MIL prediction model. - - :param adata: - AnnData object containing embeddings and covariates. - :param sample_key: - Key in `adata.obs` that corresponds to the sample covariate. - :param classification: - List of keys in `adata.obs` that correspond to the classification covariates. - :param regression: - List of keys in `adata.obs` that correspond to the regression covariates. - :param ordinal_regression: - List of keys in `adata.obs` that correspond to the ordinal regression covariates. - :param sample_batch_size: - Number of samples per bag, i.e. sample. Default is 128. - :param normalization: - One of "layer" or "batch". Default is "layer". - :param z_dim: - Dimensionality of the input latent space. Default is 16. - :param dropout: - Dropout rate. Default is 0.2. - :param scoring: - How to calculate attention scores. One of "gated_attn", "MLP". Default is "gated_attn". - :param attn_dim: - Dimensionality of the hidden layer in attention calculation. Default is 16. - :param n_layers_cell_aggregator: - Number of layers in the cell aggregator. Default is 1. - :param n_layers_classifier: - Number of layers in the classifier. Default is 2. - :param n_layers_regressor: - Number of layers in the regressor. Default is 2. - :param n_layers_mlp_attn: - Number of layers in the MLP attention. Only used if `scoring` = "MLP". Default is 1. - :param n_hidden_cell_aggregator: - Number of hidden units in the cell aggregator. Default is 128. - :param n_hidden_classifier: - Number of hidden units in the classifier. Default is 128. - :param n_hidden_mlp_attn: - Number of hidden units in the MLP attention. Default is 32. - :param n_hidden_regressor: - Number of hidden units in the regressor. Default is 128. - :param class_loss_coef: - Coefficient for the classification loss. Default is 1.0. - :param regression_loss_coef: - Coefficient for the regression loss. Default is 1.0. - :param activation: - Activation function. Default is 'leaky_relu'. - :param initialization: - Initialization method for the weights. Default is None. - :param anneal_class_loss: - Whether to anneal the classification loss. Default is False. - :param ignore_covariates: - List of covariates to ignore. Needed for query-to-reference mapping. Default is None. - """ super().__init__(adata) if classification is None: @@ -246,8 +249,7 @@ def train( path_to_checkpoints: str | None = None, **kwargs, ): - """ - Trains the model using amortized variational inference. + """Trains the model using amortized variational inference. Parameters ---------- @@ -299,6 +301,10 @@ def train( Path to save checkpoints. **kwargs Other keyword args for :class:`~scvi.train.Trainer`. + + Returns + ------- + Trainer object. """ # TODO put in a function, return params needed for splitter, plan and runner, then can call the function from multivae_mil if len(self.regression) > 0: @@ -389,16 +395,18 @@ def setup_anndata( This method will also compute the log mean and log variance per batch for the library size prior. None of the data in adata are modified. Only adds fields to adata. - :param adata: - AnnData object containing raw counts. Rows represent cells, columns represent features - :param categorical_covariate_keys: - Keys in `adata.obs` that correspond to categorical data - :param continuous_covariate_keys: - Keys in `adata.obs` that correspond to continuous data - :param ordinal_regression_order: - Dictionary with regression classes as keys and order of classes as values - :param kwargs: - Additional parameters to pass to register_fields() of AnnDataManager + Parameters + ---------- + adata + AnnData object containing raw counts. Rows represent cells, columns represent features. + categorical_covariate_keys + Keys in `adata.obs` that correspond to categorical data. + continuous_covariate_keys + Keys in `adata.obs` that correspond to continuous data. + ordinal_regression_order + Dictionary with regression classes as keys and order of classes as values. + kwargs + Additional parameters to pass to register_fields() of AnnDataManager. """ setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys) @@ -426,9 +434,11 @@ def get_model_output( ): """Save the attention scores and predictions in the adata object. - :param adata: + Parameters + ---------- + adata AnnData object to run the model on. If `None`, the model's AnnData object is used. - :param batch_size: + batch_size Minibatch size to use. Default is 256. """ @@ -552,7 +562,9 @@ def get_model_output( def plot_losses(self, save=None): """Plot losses. - :param save: + Parameters + ---------- + save If not None, save the plot to this location. """ loss_names = self.module.select_losses_to_plot() @@ -567,18 +579,24 @@ def load_query_data( adata: AnnData, reference_model: BaseModelClass, use_gpu: str | int | bool | None = None, - ): + ) -> BaseModelClass: """Online update of a reference model with scArches algorithm # TODO cite. - :param adata: + Parameters + ---------- + adata AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, as AnnData is validated against the ``registry``. - :param reference_model: - Already instantiated model of the same class - :param use_gpu: + reference_model + Already instantiated model of the same class. + use_gpu Load model on default GPU if available (if None or True), - or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False) + or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). + + Returns + ------- + Model with updated architecture and weights. """ # currently this function works only if the prediction cov is present in the .obs of the query # TODO need to allow it to be missing, maybe add a dummy column to .obs of query adata diff --git a/src/multimil/model/_multivae.py b/src/multimil/model/_multivae.py index e7ac2cf..3583229 100644 --- a/src/multimil/model/_multivae.py +++ b/src/multimil/model/_multivae.py @@ -25,49 +25,51 @@ class MultiVAE(BaseModelClass, ArchesMixin): """MultiMIL multimodal integration model. - :param adata: + Parameters + ---------- + adata AnnData object that has been registered via :meth:`~multigrate.model.MultiVAE.setup_anndata`. - :param integrate_on: + integrate_on One of the categorical covariates refistered with :math:`~multigrate.model.MultiVAE.setup_anndata` to integrate on. The latent space then will be disentangled from this covariate. If `None`, no integration is performed. - :param condition_encoders: + condition_encoders Whether to concatentate covariate embeddings to the first layer of the encoders. Default is `False`. - :param condition_decoders: + condition_decoders Whether to concatentate covariate embeddings to the first layer of the decoders. Default is `True`. - :param normalization: + normalization What normalization to use; has to be one of `batch` or `layer`. Default is `layer`. - :param z_dim: + z_dim Dimensionality of the latent space. Default is 16. - :param losses: + losses Which losses to use for each modality. Has to be the same length as the number of modalities. Default is `MSE` for all modalities. - :param dropout: + dropout Dropout rate. Default is 0.2. - :param cond_dim: + cond_dim Dimensionality of the covariate embeddings. Default is 16. - :param kernel_type: + kernel_type Type of kernel to use for the MMD loss. Default is `gaussian`. - :param loss_coefs: + loss_coefs Loss coeficients for the different losses in the model. Default is 1 for all. - :param cont_cov_type: + cont_cov_type How to calculate embeddings for continuous covariates. Default is `logsim`. - :param n_layers_cont_embed: + n_layers_cont_embed Number of layers for the continuous covariate embedding calculation. Default is 1. - :param n_layers_encoders: + n_layers_encoders Number of layers for each encoder. Default is 2 for all modalities. Has to be the same length as the number of modalities. - :param n_layers_decoders: + n_layers_decoders Number of layers for each decoder. Default is 2 for all modalities. Has to be the same length as the number of modalities. - :param n_hidden_cont_embed: + n_hidden_cont_embed Number of nodes for each hidden layer in the continuous covariate embedding calculation. Default is 32. - :param n_hidden_encoders: + n_hidden_encoders Number of nodes for each hidden layer in the encoders. Default is 32. - :param n_hidden_decoders: + n_hidden_decoders Number of nodes for each hidden layer in the decoders. Default is 32. - :param mmd: + mmd Which MMD loss to use. Default is `latent`. - :param activation: + activation Activation function to use. Default is `leaky_relu`. - :param initialization: + initialization Initialization method to use. Default is `None`. - :param ignore_covariates: + ignore_covariates List of covariates to ignore. Needed for query-to-reference mapping. Default is `None`. """ @@ -193,9 +195,11 @@ def __init__( def get_model_output(self, adata=None, batch_size=256): """Save the latent representation in the adata object. - :param adata: + Parameters + ---------- + adata AnnData object to run the model on. If `None`, the model's AnnData object is used. - :param batch_size: + batch_size Minibatch size to use. Default is 256. """ if not self.is_trained_: @@ -237,50 +241,56 @@ def train( ): """Train the model using amortized variational inference. - :param max_epochs: - Number of passes through the dataset - :param lr: - Learning rate for optimization - :param use_gpu: + Parameters + ---------- + max_epochs + Number of passes through the dataset. + lr + Learning rate for optimization. + use_gpu Use default GPU if available (if None or True), or index of GPU to use (if int), - or name of GPU (if str), or use CPU (if False) - :param train_size: - Size of training set in the range [0.0, 1.0] - :param validation_size: + or name of GPU (if str), or use CPU (if False). + train_size + Size of training set in the range [0.0, 1.0]. + validation_size Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set - :param batch_size: - Minibatch size to use during training - :param weight_decay: - Weight decay regularization term for optimization - :param eps: - Optimizer eps - :param early_stopping: - Whether to perform early stopping with respect to the validation set - :param save_best: + `train_size + validation_size < 1`, the remaining cells belong to a test set. + batch_size + Minibatch size to use during training. + weight_decay + Weight decay regularization term for optimization. + eps + Optimizer eps. + early_stopping + Whether to perform early stopping with respect to the validation set. + save_best Save the best model state with respect to the validation loss, or use the final - state in the training procedure - :param check_val_every_n_epoch: + state in the training procedure. + check_val_every_n_epoch Check val every n train epochs. By default, val is not checked, unless `early_stopping` is `True`. - If so, val is checked every epoch - :param n_epochs_kl_warmup: + If so, val is checked every epoch. + n_epochs_kl_warmup Number of epochs to scale weight on KL divergences from 0 to 1. Overrides `n_steps_kl_warmup` when both are not `None`. Default is 1/3 of `max_epochs`. - :param n_steps_kl_warmup: + n_steps_kl_warmup Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults - to `floor(0.75 * adata.n_obs)` - :param adversarial_mixing: + to `floor(0.75 * adata.n_obs)`. + adversarial_mixing Whether to use adversarial mixing. Default is `False`. - :param plan_kwargs: + plan_kwargs Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate - :param save_checkpoint_every_n_epochs: + `train()` will overwrite values present in `plan_kwargs`, when appropriate. + save_checkpoint_every_n_epochs Save a checkpoint every n epochs. If `None`, no checkpoints are saved. - :param path_to_checkpoints: - Path to save checkpoints. Required if `save_checkpoint_every_n_epochs` is not `None` - :param kwargs: - Additional keyword arguments for :class:`~scvi.train.TrainRunner` + path_to_checkpoints + Path to save checkpoints. Required if `save_checkpoint_every_n_epochs` is not `None`. + kwargs + Additional keyword arguments for :class:`~scvi.train.TrainRunner`. + + Returns + ------- + Trainer object. """ if n_epochs_kl_warmup is None: n_epochs_kl_warmup = max(max_epochs // 3, 1) @@ -369,18 +379,20 @@ def setup_anndata( This method will also compute the log mean and log variance per batch for the library size prior. None of the data in adata are modified. Only adds fields to adata. - :param adata: - AnnData object containing raw counts. Rows represent cells, columns represent features - :param size_factor_key: + Parameters + ---------- + adata + AnnData object containing raw counts. Rows represent cells, columns represent features. + size_factor_key Key in `adata.obs` containing the size factor. If `None`, will be calculated from the RNA counts. - :param rna_indices_end: + rna_indices_end Integer to indicate where RNA feature end in the AnnData object. Is used to calculate ``libary_size``. - :param categorical_covariate_keys: - Keys in `adata.obs` that correspond to categorical data - :param continuous_covariate_keys: - Keys in `adata.obs` that correspond to continuous data - :param kwargs: - Additional parameters to pass to register_fields() of AnnDataManager + categorical_covariate_keys + Keys in `adata.obs` that correspond to categorical data. + continuous_covariate_keys + Keys in `adata.obs` that correspond to continuous data. + kwargs + Additional parameters to pass to register_fields() of AnnDataManager. """ setup_method_args = cls._get_setup_method_args(**locals()) @@ -403,7 +415,9 @@ def setup_anndata( def plot_losses(self, save=None): """Plot losses. - :param save: + Parameters + ---------- + save If not None, save the plot to this location. """ loss_names = self.module.select_losses_to_plot() @@ -420,22 +434,28 @@ def load_query_data( use_gpu: str | int | bool | None = None, freeze: bool = True, ignore_covariates: list[str] | None = None, - ): + ) -> BaseModelClass: """Online update of a reference model with scArches algorithm # TODO cite. - :param adata: + Parameters + ---------- + adata AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, as AnnData is validated against the ``registry``. - :param reference_model: - Already instantiated model of the same class - :param use_gpu: + reference_model + Already instantiated model of the same class. + use_gpu Load model on default GPU if available (if None or True), - or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False) - :param freeze: - Whether to freeze the encoders and decoders and only train the new weights - :param ignore_covariates: + or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). + freeze + Whether to freeze the encoders and decoders and only train the new weights. + ignore_covariates List of covariates to ignore. Needed for query-to-reference mapping. Default is `None`. + + Returns + ------- + Model with updated architecture and weights. """ _, _, device = parse_use_gpu_arg(use_gpu) diff --git a/src/multimil/model/_multivae_mil.py b/src/multimil/model/_multivae_mil.py index 3915da5..2b87d93 100644 --- a/src/multimil/model/_multivae_mil.py +++ b/src/multimil/model/_multivae_mil.py @@ -33,6 +33,92 @@ class MultiVAE_MIL(BaseModelClass, ArchesMixin): + """MultiVAE_MIL model. + + Parameters + ---------- + adata + Annotated data object. + sample_key + Key for the sample column in the adata object. + classification + List of keys for the categorical covariates used for classification. + regression + List of keys for the continuous covariates used for regression. + ordinal_regression + List of keys for the ordinal covariates used for ordinal regression. + sample_batch_size + Bag batch size for training the model. + integrate_on + Key for the covariate used for integration. + condition_encoders + Whether to condition the encoders on the covariates. + condition_decoders + Whether to condition the decoders on the covariates. + normalization + Type of normalization to be applied. + n_layers_encoders + Number of layers in the encoders. + n_layers_decoders + Number of layers in the decoders. + n_hidden_encoders + Number of hidden units in the encoders. + n_hidden_decoders + Number of hidden units in the decoders. + z_dim + Dimensionality of the latent space. + losses + List of loss functions to be used. + dropout + Dropout rate. + cond_dim + Dimensionality of the covariate embeddings. + kernel_type + Type of kernel to be used in MMD calculation. + loss_coefs + List of coefficients for different losses. + scoring + Scoring method for the MIL classifier. + attn_dim + Dimensionality of the hidden dimentino in the attention mechanism. + n_layers_cell_aggregator + Number of layers in the cell aggregator. + n_layers_classifier + Number of layers in the classifier. + n_layers_regressor + Number of layers in the regressor. + n_layers_mlp_att + Number of layers in the MLP attention mechanism. + n_layers_cont_embed + Number of layers in the continuous embedding. + n_hidden_cell_aggregator + Number of hidden units in the cell aggregator. + n_hidden_classifier + Number of hidden units in the classifier. + n_hidden_cont_embed + Number of hidden units in the continuous embedding. + n_hidden_mlp_attn + Number of hidden units in the MLP attention mechanism. + n_hidden_regressor + Number of hidden units in the regressor. + class_loss_coef + Coefficient for the classification loss. + regression_loss_coef + Coefficient for the regression loss. + cont_cov_type + How to calucate the embeddings for the continuous covariates. + mmd + Type of maximum mean discrepancy. + sample_in_vae + Whether to include the sample key in the VAE as a covariate. + activation + Activation function to be used. + initialization + Initialization method for the weights. + anneal_class_loss + Whether to anneal the classification loss. + """ + def __init__( self, adata, @@ -76,50 +162,6 @@ def __init__( initialization="kaiming", # xavier (tanh) or kaiming (leaky_relu) anneal_class_loss=False, ): - """ - Initialize the MultiVAE_MIL model. - - :param adata: Annotated data object. - :param sample_key: Key for the sample column in the adata object. - :param classification: List of keys for the categorical covariates used for classification. - :param regression: List of keys for the continuous covariates used for regression. - :param ordinal_regression: List of keys for the ordinal covariates used for ordinal regression. - :param sample_batch_size: Batch size for training the model. - :param integrate_on: Key for the covariate used for integration. - :param condition_encoders: Whether to condition the encoders on the covariates. - :param condition_decoders: Whether to condition the decoders on the covariates. - :param normalization: Type of normalization to be applied. - :param n_layers_encoders: Number of layers in the encoders. - :param n_layers_decoders: Number of layers in the decoders. - :param n_hidden_encoders: Number of hidden units in the encoders. - :param n_hidden_decoders: Number of hidden units in the decoders. - :param z_dim: Dimensionality of the latent space. - :param losses: List of loss functions to be used. - :param dropout: Dropout rate. - :param cond_dim: Dimensionality of the conditional covariates. - :param kernel_type: Type of kernel to be used. - :param loss_coefs: List of coefficients for the loss functions. - :param scoring: Scoring method for the MIL classifier. - :param attn_dim: Dimensionality of the attention mechanism. - :param n_layers_cell_aggregator: Number of layers in the cell aggregator. - :param n_layers_classifier: Number of layers in the classifier. - :param n_layers_regressor: Number of layers in the regressor. - :param n_layers_mlp_attn: Number of layers in the MLP attention mechanism. - :param n_layers_cont_embed: Number of layers in the continuous embedding. - :param n_hidden_cell_aggregator: Number of hidden units in the cell aggregator. - :param n_hidden_classifier: Number of hidden units in the classifier. - :param n_hidden_cont_embed: Number of hidden units in the continuous embedding. - :param n_hidden_mlp_attn: Number of hidden units in the MLP attention mechanism. - :param n_hidden_regressor: Number of hidden units in the regressor. - :param class_loss_coef: Coefficient for the classification loss. - :param regression_loss_coef: Coefficient for the regression loss. - :param cont_cov_type: Type of continuous covariate. - :param mmd: Type of maximum mean discrepancy. - :param sample_in_vae: Whether to include the sample key in the VAE. - :param activation: Activation function to be used. - :param initialization: Initialization method for the model. - :param anneal_class_loss: Whether to anneal the classification loss. - """ super().__init__(adata) if classification is None: @@ -303,8 +345,7 @@ def train( path_to_checkpoints: str | None = None, **kwargs, ): - """ - Trains the model using amortized variational inference. + """Trains the model. Parameters ---------- @@ -356,6 +397,10 @@ def train( Path to save checkpoints. **kwargs Other keyword args for :class:`~scvi.train.Trainer`. + + Returns + ------- + Trainer object. """ if len(self.mil.regression) > 0: if early_stopping_monitor == "accuracy_validation": @@ -445,20 +490,22 @@ def setup_anndata( This method will also compute the log mean and log variance per batch for the library size prior. None of the data in adata are modified. Only adds fields to adata. - :param adata: - AnnData object containing raw counts. Rows represent cells, columns represent features - :param size_factor_key: + Parameters + ---------- + adata + AnnData object containing raw counts. Rows represent cells, columns represent features. + size_factor_key Key in `adata.obs` containing the size factor. If `None`, will be calculated from the RNA counts. - :param rna_indices_end: + rna_indices_end Integer to indicate where RNA feature end in the AnnData object. May be needed to calculate ``libary_size``. - :param categorical_covariate_keys: - Keys in `adata.obs` that correspond to categorical data - :param continuous_covariate_keys: - Keys in `adata.obs` that correspond to continuous data - :param ordinal_regression_order: - Dictionary with regression classes as keys and order of classes as values - :param kwargs: - Additional parameters to pass to register_fields() of AnnDataManager + categorical_covariate_keys + Keys in `adata.obs` that correspond to categorical data. + continuous_covariate_keys + Keys in `adata.obs` that correspond to continuous data. + ordinal_regression_order + Dictionary with regression classes as keys and order of classes as values. + kwargs + Additional parameters to pass to register_fields() of AnnDataManager. """ setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys) @@ -489,9 +536,11 @@ def get_model_output( ): """Save the latent representation, attention scores and predictions in the adata object. - :param adata: + Parameters + ---------- + adata AnnData object to run the model on. If `None`, the model's AnnData object is used. - :param batch_size: + batch_size Minibatch size to use. Default is 256. """ @@ -637,7 +686,9 @@ def get_model_output( def plot_losses(self, save=None): """Plot losses. - :param save: + Parameters + ---------- + save If not None, save the plot to this location. """ loss_names = [] @@ -656,22 +707,28 @@ def load_query_data( use_gpu: str | int | bool | None = None, freeze: bool = True, ignore_covariates: list[str] | None = None, - ): + ) -> BaseModelClass: """Online update of a reference model with scArches algorithm # TODO cite. - :param adata: + Parameters + ---------- + adata AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, as AnnData is validated against the ``registry``. - :param reference_model: - Already instantiated model of the same class - :param use_gpu: + reference_model + Already instantiated model of the same class. + use_gpu Load model on default GPU if available (if None or True), - or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False) - :param freeze: - Whether to freeze the encoders and decoders and only train the new weights - :param ignore_covariates: + or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). + freeze + Whether to freeze the encoders and decoders and only train the new weights. + ignore_covariates List of covariates to ignore. Needed for query-to-reference mapping. Default is `None`. + + Returns + ------- + Model with updated architecture and weights. """ _, _, device = parse_use_gpu_arg(use_gpu) @@ -762,55 +819,56 @@ def train_vae( ): """Train the VAE part of the model. - :param max_epochs: + Parameters + ---------- + max_epochs Number of passes through the dataset. - :param lr: + lr Learning rate for optimization. - :param use_gpu: + use_gpu Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). - :param train_size: + train_size Size of training set in the range [0.0, 1.0]. - :param validation_size: + validation_size Size of the test set. If `None`, defaults to 1 - `train_size`. If `train_size + validation_size < 1`, the remaining cells belong to a test set. - :param batch_size: + batch_size Minibatch size to use during training. - :param weight_decay: + weight_decay weight decay regularization term for optimization - :param eps: - Optimizer eps - :param early_stopping: + eps + Optimizer eps. + early_stopping Whether to perform early stopping with respect to the validation set. - :param save_best: + save_best Save the best model state with respect to the validation loss, or use the final - state in the training procedure - :param check_val_every_n_epoch: + state in the training procedure. + check_val_every_n_epoch Check val every n train epochs. By default, val is not checked, unless `early_stopping` is `True`. If so, val is checked every epoch. - :param n_epochs_kl_warmup: + n_epochs_kl_warmup Number of epochs to scale weight on KL divergences from 0 to 1. Overrides `n_steps_kl_warmup` when both are not `None`. Default is 1/3 of `max_epochs`. - :param n_steps_kl_warmup: + n_steps_kl_warmup Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults - to `floor(0.75 * adata.n_obs)` - :param adversarial_mixing: + to `floor(0.75 * adata.n_obs)`. + adversarial_mixing Whether to use adversarial mixing in the training procedure. - :param plan_kwargs: + plan_kwargs Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to `train()` will overwrite values present in `plan_kwargs`, when appropriate. - :param plot_losses: + plot_losses Whether to plot the losses. - :param save_loss: + save_loss If not None, save the plot to this location. - :param save_checkpoint_every_n_epochs: + save_checkpoint_every_n_epochs Save a checkpoint every n epochs. - :param path_to_checkpoints: + path_to_checkpoints Path to save checkpoints. - :param kwargs: + kwargs Other keyword args for :class:`~scvi.train.Trainer`. - """ # TODO add a check if there are any new params added in load_query_data, i.e. if there are any new params that can be trained vae = self.multivae diff --git a/src/multimil/module/_mil_torch.py b/src/multimil/module/_mil_torch.py index fd837bb..a5386bb 100644 --- a/src/multimil/module/_mil_torch.py +++ b/src/multimil/module/_mil_torch.py @@ -9,6 +9,58 @@ class MILClassifierTorch(BaseModuleClass): + """MultiMIL's MIL classification module. + + Parameters + ---------- + z_dim + Latent dimension. + dropout + Dropout rate. + normalization + Normalization type. + num_classification_classes + Number of classes for each of the classification task. + scoring + Scoring type. One of ["gated_attn", "attn", "mlp"]. + attn_dim + Hidden attention dimension. + n_layers_cell_aggregator + Number of layers in the cell aggregator. + n_layers_classifier + Number of layers in the classifier. + n_layers_mlp_attn + Number of layers in the MLP attention. + n_layers_regressor + Number of layers in the regressor. + n_hidden_regressor + Hidden dimension in the regressor. + n_hidden_cell_aggregator + Hidden dimension in the cell aggregator. + n_hidden_classifier + Hidden dimension in the classifier. + n_hidden_mlp_attn + Hidden dimension in the MLP attention. + class_loss_coef + Classification loss coefficient. + regression_loss_coef + Regression loss coefficient. + sample_batch_size + Sample batch size. + class_idx + Which indices in cat covariates to do classification on. + ord_idx + Which indices in cat covariates to do ordinal regression on. + reg_idx + Which indices in cont covariates to do regression on. + activation + Activation function. + initialization + Initialization type. + anneal_class_loss + Whether to anneal the classification loss. + """ + def __init__( self, z_dim=16, @@ -143,7 +195,18 @@ def _get_generative_input(self, tensors, inference_outputs): return {"z_joint": z_joint} @auto_move_data - def inference(self, x): + def inference(self, x) -> dict[str, torch.Tensor | list[torch.Tensor]]: + """Forward pass for inference. + + Parameters + ---------- + x + Input. + + Returns + ------- + Predictions. + """ z_joint = x inference_outputs = {"z_joint": z_joint} @@ -171,7 +234,19 @@ def inference(self, x): return inference_outputs # z_joint, mu, logvar, predictions @auto_move_data - def generative(self, z_joint): + def generative(self, z_joint) -> torch.Tensor: + # TODO even if not used, make consistent with the rest, i.e. return dict + """Forward pass for generative. + + Parameters + ---------- + z_joint + Latent embeddings. + + Returns + ------- + Same as input. + """ return z_joint def _calculate_loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): @@ -246,6 +321,23 @@ def _calculate_loss(self, tensors, inference_outputs, generative_outputs, kl_wei return loss, recon_loss, kl_loss, extra_metrics def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): + """Loss calculation. + + Parameters + ---------- + tensors + Input tensors. + inference_outputs + Inference outputs. + generative_outputs + Generative outputs. + kl_weight + KL weight. Default is 1.0. + + Returns + ------- + Prediction loss. + """ loss, recon_loss, kl_loss, extra_metrics = self._calculate_loss( tensors, inference_outputs, generative_outputs, kl_weight ) @@ -258,6 +350,12 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float ) def select_losses_to_plot(self): + """Select losses to plot. + + Returns + ------- + Loss names. + """ loss_names = [] if self.class_loss_coef != 0 and len(self.class_idx) > 0: loss_names.extend(["class_loss", "accuracy"]) @@ -266,7 +364,3 @@ def select_losses_to_plot(self): if self.regression_loss_coef != 0 and len(self.ord_idx) > 0: loss_names.extend(["regression_loss", "accuracy"]) return loss_names - - @torch.inference_mode() - def sample(self, tensors, n_samples=1): - return self.vae.sample(tensors, n_samples) diff --git a/src/multimil/module/_multivae_mil_torch.py b/src/multimil/module/_multivae_mil_torch.py index 4811d8b..3572977 100644 --- a/src/multimil/module/_multivae_mil_torch.py +++ b/src/multimil/module/_multivae_mil_torch.py @@ -1,3 +1,4 @@ +import torch from scvi import REGISTRY_KEYS from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data @@ -5,6 +6,100 @@ class MultiVAETorch_MIL(BaseModuleClass): + """MultiMIL's end-to-end multimodal integration and MIL classification modules. + + Parameters + ---------- + modality_lengths + Number of features for each modality. + condition_encoders + Whether to condition the encoders on the covariates. + condition_decoders + Whether to condition the decoders on the covariates. + normalization + Normalization to use in the network. + z_dim + Dimensionality of the latent space. + losses + List of losses to use in the VAE. + dropout + Dropout rate. + cond_dim + Dimensionality of the covariate embeddings. + kernel_type + Type of kernel to use for the MMD loss. + loss_coefs + Coefficients for the different losses. + num_groups + Number of groups to use for the MMD loss. + integrate_on_idx + Indices of the covariates to integrate on. + n_layers_encoders + Number of layers in the encoders. + n_layers_decoders + Number of layers in the decoders. + n_hidden_encoders + Number of hidden units in the encoders. + n_hidden_decoders + Number of hidden units in the decoders. + num_classification_classes + Number of classes for each of the classification task. + scoring + Scoring function to use for the MIL classification. + attn_dim + Dimensionality of the hidden attention dimension. + cat_covariate_dims + Number of categories for each of the categorical covariates. + cont_covariate_dims + Number of categories for each of the continuous covariates. Always 1. + cat_covs_idx + Indices of the categorical covariates. + cont_covs_idx + Indices of the continuous covariates. + cont_cov_type + Type of continuous covariate. + n_layers_cell_aggregator + Number of layers in the cell aggregator. + n_layers_classifier + Number of layers in the classifier. + n_layers_mlp_attn + Number of layers in the attention MLP. + n_layers_cont_embed + Number of layers in the continuous embedding calculation. + n_layers_regressor + Number of layers in the regressor. + n_hidden_regressor + Number of hidden units in the regressor. + n_hidden_cell_aggregator + Number of hidden units in the cell aggregator. + n_hidden_classifier + Number of hidden units in the classifier. + n_hidden_mlp_attn + Number of hidden units in the attention MLP. + n_hidden_cont_embed + Number of hidden units in the continuous embedding calculation. + class_loss_coef + Coefficient for the classification loss. + regression_loss_coef + Coefficient for the regression loss. + sample_batch_size + Bag size. + class_idx + Which indices in cat covariates to do classification on. + ord_idx + Which indices in cat covariates to do ordinal regression on. + reg_idx + Which indices in cont covariates to do regression on. + mmd + Type of MMD loss to use. + activation + Activation function to use. + initialization + Initialization method to use. + anneal_class_loss + Whether to anneal the classification loss. + """ + def __init__( self, modality_lengths, @@ -131,7 +226,22 @@ def _get_generative_input(self, tensors, inference_outputs): return {"z_joint": z_joint, "cat_covs": cat_covs, "cont_covs": cont_covs} @auto_move_data - def inference(self, x, cat_covs, cont_covs): + def inference(self, x, cat_covs, cont_covs) -> dict[str, torch.Tensor | list[torch.Tensor]]: + """Forward pass for inference. + + Parameters + ---------- + x + Input. + cat_covs + Categorical covariates to condition on. + cont_covs + Continuous covariates to condition on. + + Returns + ------- + Joint representations, marginal representations, joint mu's and logvar's and predictions. + """ # VAE part inference_outputs = self.vae_module.inference(x, cat_covs, cont_covs) z_joint = inference_outputs["z_joint"] @@ -142,10 +252,42 @@ def inference(self, x, cat_covs, cont_covs): return inference_outputs # z_joint, mu, logvar, z_marginal, predictions @auto_move_data - def generative(self, z_joint, cat_covs, cont_covs): + def generative(self, z_joint, cat_covs, cont_covs) -> dict[str, torch.Tensor]: + """Compute necessary inference quantities. + + Parameters + ---------- + z_joint + Tensor of values with shape ``(batch_size, z_dim)``. + cat_covs + Categorical covariates to condition on. + cont_covs + Continuous covariates to condition on. + + Returns + ------- + Reconstructed values for each modality. + """ return self.vae_module.generative(z_joint, cat_covs, cont_covs) def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): + """Calculate the (modality) reconstruction loss, Kullback divergences and integration loss. + + Parameters + ---------- + tensors + Tensor of values with shape ``(batch_size, n_input_features)``. + inference_outputs + Dictionary with the inference output. + generative_outputs + Dictionary with the generative output. + kl_weight + Weight of the KL loss. Default is 1.0. + + Returns + ------- + Reconstruction loss, Kullback divergences, integration loss, modality reconstruction and prediction losses. + """ loss_vae, recon_loss, kl_loss, extra_metrics = self.vae_module._calculate_loss( tensors, inference_outputs, generative_outputs, kl_weight ) diff --git a/src/multimil/module/_multivae_torch.py b/src/multimil/module/_multivae_torch.py index 527f94d..c22d3c8 100644 --- a/src/multimil/module/_multivae_torch.py +++ b/src/multimil/module/_multivae_torch.py @@ -14,66 +14,68 @@ class MultiVAETorch(BaseModuleClass): - """The multigrate model implemented following scvi-tools module sctructure. - - :param modality_lengths: - List with lengths of each modality - :param condition_encoders: - Boolean to indicate if to condition encoders - :param condition_decoders: - Boolean to indicate if to condition decoders - :param normalization: + """MultiMIL's multimodal integration module. + + Parameters + ---------- + modality_lengths + List with lengths of each modality. + condition_encoders + Boolean to indicate if to condition encoders. + condition_decoders + Boolean to indicate if to condition decoders. + normalization One of the following * ``'layer'`` - layer normalization * ``'batch'`` - batch normalization - * ``None`` - no normalization - :param z_dim: - Dimensionality of the latent space - :param losses: + * ``None`` - no normalization. + z_dim + Dimensionality of the latent space. + losses List of which losses to use. For each modality can be one of the following: * ``'mse'`` - mean squared error * ``'nb'`` - negative binomial * ``zinb`` - zero-inflated negative binomial - * ``bce`` - binary cross-entropy - :param dropout: - Dropout rate for neural networks - :param cond_dim: - Dimensionality of the covariate embeddings - :param kernel_type: + * ``bce`` - binary cross-entropy. + dropout + Dropout rate for neural networks. + cond_dim + Dimensionality of the covariate embeddings. + kernel_type One of the following: * ``'gaussian'`` - Gaussian kernel - * ``'not gaussian'`` - not Gaussian kernel - :param loss_coefs: - Dictionary with weights for each of the losses - :param num_groups: - Number of groups to integrate on - :param integrate_on_idx: - Indices on which to integrate on - :param cat_covariate_dims: - List with number of classes for each of the categorical covariates - :param cont_covariate_dims: - List of 1's for each of the continuous covariate - :param cont_cov_type: + * ``'not gaussian'`` - not Gaussian kernel. + loss_coefs + Dictionary with weights for each of the losses. + num_groups + Number of groups to integrate on. + integrate_on_idx + Indices on which to integrate on. + cat_covariate_dims + List with number of classes for each of the categorical covariates. + cont_covariate_dims + List of 1's for each of the continuous covariate. + cont_cov_type How to transform continuous covariate before multiplying with the embedding. One of the following: * ``'logsim'`` - generalized sigmoid - * ``'mlp'`` - MLP - :param n_layers_cont_embed: - Number of layers for the transformation of the continuous covariates before multiplying with the embedding - :param n_hidden_cont_embed: - Number of nodes in hidden layers in the network that transforms continuous covariates - :param n_layers_encoders: - Number of layers in each encoder - :param n_layers_decoders: - Number of layers in each decoder - :param n_hidden_encoders: - Number of nodes in hidden layers in encoders - :param n_hidden_decoders: - Number of nodes in hidden layers in decoders - :param mmd: + * ``'mlp'`` - MLP. + n_layers_cont_embed + Number of layers for the transformation of the continuous covariates before multiplying with the embedding. + n_hidden_cont_embed + Number of nodes in hidden layers in the network that transforms continuous covariates. + n_layers_encoders + Number of layers in each encoder. + n_layers_decoders + Number of layers in each decoder. + n_hidden_encoders + Number of nodes in hidden layers in encoders. + n_hidden_decoders + Number of nodes in hidden layers in decoders. + mmd How to calculate MMD loss. One of the following * ``'latent'`` - only on the latent representations * ``'marginal'`` - only on the marginal representations - * ``both`` - the sum of the two above + * ``both`` - the sum of the two above. """ def __init__( @@ -343,16 +345,20 @@ def inference( ) -> dict[str, torch.Tensor | list[torch.Tensor]]: """Compute necessary inference quantities. - :param x: - Tensor of values with shape ``(batch_size, n_input_features)`` - :param cat_covs: - Categorical covariates to condition on - :param cont_covs: - Continuous covariates to condition on - :param masks: - List of binary tensors indicating which values in ``x`` belong to which modality - :returns: - Joint representations, marginal representations, joint mu's and logvar's. + Parameters + ---------- + x + Tensor of values with shape ``(batch_size, n_input_features)``. + cat_covs + Categorical covariates to condition on. + cont_covs + Continuous covariates to condition on. + masks + List of binary tensors indicating which values in ``x`` belong to which modality. + + Returns + ------- + Joint representations, marginal representations, joint mu's and logvar's. """ # split x into modality xs if torch.is_tensor(x): @@ -362,6 +368,7 @@ def inference( else: xs = x + # TODO: check if masks still supported if masks is None: masks = [x.sum(dim=1) > 0 for x in xs] # list of masks per modality masks = torch.stack(masks, dim=1) @@ -401,14 +408,18 @@ def generative( ) -> dict[str, list[torch.Tensor]]: """Compute necessary inference quantities. - :param z_joint: - Tensor of values with shape ``(batch_size, z_dim)`` - :param cat_covs: - Categorical covariates to condition on - :param cont_covs: - Continuous covariates to condition on - :returns: - Reconstructed values for each modality. + Parameters + ---------- + z_joint + Tensor of values with shape ``(batch_size, z_dim)``. + cat_covs + Categorical covariates to condition on. + cont_covs + Continuous covariates to condition on. + + Returns + ------- + Reconstructed values for each modality. """ z = z_joint.unsqueeze(1).repeat(1, self.n_modality, 1) zs = torch.split(z, 1, dim=1) @@ -534,16 +545,20 @@ def loss( ]: """Calculate the (modality) reconstruction loss, Kullback divergences and integration loss. - :param tensors: - Tensor of values with shape ``(batch_size, n_input_features)`` - :param inference_outputs: - Dictionary with the inference output - :param generative_outputs: - Dictionary with the generative output - :param kl_weight: - Weight of the KL loss - :returns: - Reconstruction loss, Kullback divergences, integration loss and modality reconstruction losses. + Parameters + ---------- + tensors + Tensor of values with shape ``(batch_size, n_input_features)``. + inference_outputs + Dictionary with the inference output. + generative_outputs + Dictionary with the generative output. + kl_weight + Weight of the KL loss. Default is 1.0. + + Returns + ------- + Reconstruction loss, Kullback divergences, integration loss and modality reconstruction losses. """ loss, recon_loss, kl_loss, extra_metrics = self._calculate_loss( tensors, inference_outputs, generative_outputs, kl_weight @@ -556,20 +571,32 @@ def loss( extra_metrics=extra_metrics, ) - @torch.inference_mode() - def sample(self, tensors, n_samples=1): - """Sample from the generative model.""" - inference_kwargs = {"n_samples": n_samples} - with torch.inference_mode(): - ( - _, - generative_outputs, - ) = self.forward( - tensors, - inference_kwargs=inference_kwargs, - compute_loss=False, - ) - return generative_outputs["rs"] + # @torch.inference_mode() + # def sample(self, tensors, n_samples=1): + # """Sample from the generative model. + + # Parameters + # ---------- + # tensors + # Tensor of values. + # n_samples + # Number of samples to generate. + + # Returns + # ------- + # Generative outputs. + # """ + # inference_kwargs = {"n_samples": n_samples} + # with torch.inference_mode(): + # ( + # _, + # generative_outputs, + # ) = self.forward( + # tensors, + # inference_kwargs=inference_kwargs, + # compute_loss=False, + # ) + # return generative_outputs["rs"] def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks): loss = [] @@ -643,6 +670,12 @@ def _compute_cont_cov_embeddings(self, covs): return self.cont_covariate_curves(covs) @ self.cont_covariate_embeddings.weight def select_losses_to_plot(self): + """Select losses to plot. + + Returns + ------- + Loss names. + """ loss_names = ["kl_local", "elbo", "reconstruction_loss"] for i in range(self.n_modality): loss_names.append(f"modality_{i}_reconstruction_loss") diff --git a/src/multimil/nn/_base_components.py b/src/multimil/nn/_base_components.py index 43d289a..7965ba1 100644 --- a/src/multimil/nn/_base_components.py +++ b/src/multimil/nn/_base_components.py @@ -7,7 +7,26 @@ class MLP(nn.Module): - """A helper class to build blocks of fully-connected, normalization, dropout and activation layers.""" + """A helper class to build blocks of fully-connected, normalization, dropout and activation layers. + + Parameters + ---------- + n_input + Number of input features. + n_output + Number of output features. + n_layers + Number of hidden layers. + n_hidden + Number of hidden units. + dropout_rate + Dropout rate. + normalization + Type of normalization to use. Can be one of ["layer", "batch", "none"]. + activation + Activation function to use. + + """ def __init__( self, @@ -42,16 +61,40 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward computation on ``x``. - :param x: - tensor of values with shape ``(n_in,)`` - :returns: - tensor of values with shape ``(n_out,)`` + Parameters + ---------- + x + Tensor of values with shape ``(n_input,)``. + + Returns + ------- + Tensor of values with shape ``(n_output,)``. """ return self.mlp(x) class Decoder(nn.Module): - """A helper class to build custom decoders depending on which loss was passed.""" + """A helper class to build custom decoders depending on which loss was passed. + + Parameters + ---------- + n_input + Number of input features. + n_output + Number of output features. + n_layers + Number of hidden layers. + n_hidden + Number of hidden units. + dropout_rate + Dropout rate. + normalization + Type of normalization to use. Can be one of ["layer", "batch", "none"]. + activation + Activation function to use. + loss + Loss function to use. Can be one of ["mse", "nb", "zinb", "bce"]. + """ def __init__( self, @@ -103,10 +146,14 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward computation on ``x``. - :param x: - tensor of values with shape ``(n_in,)`` - :returns: - tensor of values with shape ``(n_out,)`` + Parameters + ---------- + x + Tensor of values with shape ``(n_input,)``. + + Returns + ------- + Tensor of values with shape ``(n_output,)``. """ x = self.decoder(x) if self.loss == "mse" or self.loss == "bce": @@ -125,6 +172,13 @@ class GeneralizedSigmoid(nn.Module): Date: 26.01.2022 Link to the used code: https://github.com/facebookresearch/CPA/blob/382ff641c588820a453d801e5d0e5bb56642f282/compert/model.py#L109 + + Parameters + ---------- + dim + Number of input features. + nonlin + Type of non-linearity to use. Can be one of ["logsigm", "sigm"]. Default is "logsigm". """ def __init__(self, dim, nonlin: Literal["logsigm", "sigm"] | None = "logsigm"): @@ -133,8 +187,18 @@ def __init__(self, dim, nonlin: Literal["logsigm", "sigm"] | None = "logsigm"): self.beta = torch.nn.Parameter(torch.ones(1, dim), requires_grad=True) self.bias = torch.nn.Parameter(torch.zeros(1, dim), requires_grad=True) - def forward(self, x): - """Forward computation on ``x``.""" + def forward(self, x) -> torch.Tensor: + """Forward computation on ``x``. + + Parameters + ---------- + x + Tensor of values. + + Returns + ------- + Tensor of values with the same shape as ``x``. + """ if self.nonlin == "logsigm": return (torch.log1p(x) * self.beta + self.bias).sigmoid() elif self.nonlin == "sigm": @@ -144,7 +208,30 @@ def forward(self, x): class Aggregator(nn.Module): - # TODO add docstring + """A helper class to build custom aggregators depending on the scoring function passed. + + Parameters + ---------- + n_input + Number of input features. + scoring + Scoring function to use. Can be one of ["attn", "gated_attn", "mlp"]. + attn_dim + Dimension of the hidden attention layer. + sample_batch_size + Bag batch size. + scale + Whether to scale the attention weights. + dropout + Dropout rate. + n_layers_mlp_attn + Number of hidden layers in the MLP attention. + n_hidden_mlp_attn + Number of hidden units in the MLP attention. + activation + Activation function to use. + """ + def __init__( self, n_input=None, @@ -200,10 +287,19 @@ def __init__( nn.Linear(n_hidden_mlp_attn, 1), ) - def forward(self, x): - # if self.scoring == "sum": - # return torch.sum(x, dim=0) # z_dim depricated + def forward(self, x) -> torch.Tensor: + """Forward computation on ``x``. + + Parameters + ---------- + x + Tensor of values with shape ``(n_input,)``. + Returns + ------- + Tensor of pooled values. + """ + # TODO add sum, mean and max pooling if self.scoring == "attn": # from https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py (accessed 16.09.2021) self.A = self.attention(x) # Nx1 @@ -231,4 +327,4 @@ def forward(self, x): if self.scale: self.A = self.A * self.A.shape[-1] / self.patient_batch_size - return torch.bmm(self.A, x).squeeze(dim=1) # z_dim + return torch.bmm(self.A, x).squeeze(dim=1) diff --git a/src/multimil/utils/_utils.py b/src/multimil/utils/_utils.py index 795820d..e66c0e5 100644 --- a/src/multimil/utils/_utils.py +++ b/src/multimil/utils/_utils.py @@ -7,7 +7,22 @@ from matplotlib import pyplot as plt -def create_df(pred, columns=None, index=None): +def create_df(pred, columns=None, index=None) -> pd.DataFrame: + """Create a pandas DataFrame from a list of predictions. + + Parameters + ---------- + pred + List of predictions. + columns + Column names, i.e. class_names. + index + Index names, i.e. obs_names. + + Returns + ------- + DataFrame with predictions. + """ if isinstance(pred, dict): for key in pred.keys(): pred[key] = torch.cat(pred[key]).squeeze().cpu().numpy() @@ -22,7 +37,22 @@ def create_df(pred, columns=None, index=None): return df -def calculate_size_factor(adata, size_factor_key, rna_indices_end): +def calculate_size_factor(adata, size_factor_key, rna_indices_end) -> str: + """Calculate size factors. + + Parameters + ---------- + adata + Annotated data object. + size_factor_key + Key in `adata.obs` where size factors are stored. + rna_indices_end + Index of the last RNA feature in the data. + + Returns + ------- + Size factor key. + """ # TODO check that organize_multiome_anndatas was run, i.e. that .uns['modality_lengths'] was added, needed for q2r if size_factor_key is not None and rna_indices_end is not None: raise ValueError( @@ -44,6 +74,17 @@ def calculate_size_factor(adata, size_factor_key, rna_indices_end): def setup_ordinal_regression(adata, ordinal_regression_order, categorical_covariate_keys): + """Setup ordinal regression. + + Parameters + ---------- + adata + Annotated data object. + ordinal_regression_order + Order of categories for ordinal regression. + categorical_covariate_keys + Keys of categorical covariates. + """ # TODO make sure not to assume categorical columns for ordinal regression -> change to np.unique if needed if ordinal_regression_order is not None: if not set(ordinal_regression_order.keys()).issubset(categorical_covariate_keys): @@ -59,7 +100,22 @@ def setup_ordinal_regression(adata, ordinal_regression_order, categorical_covari adata.obs[key] = adata.obs[key].cat.reorder_categories(ordinal_regression_order[key]) -def select_covariates(covs, prediction_idx, n_samples_in_batch): +def select_covariates(covs, prediction_idx, n_samples_in_batch) -> torch.Tensor: + """Select prediction covariates from all covariates. + + Parameters + ---------- + covs + Covariates. + prediction_idx + Index of predictions. + n_samples_in_batch + Number of samples in the batch. + + Returns + ------- + Prediction covariates. + """ if len(prediction_idx) > 0: covs = torch.index_select(covs, 1, prediction_idx) covs = covs.view(n_samples_in_batch, -1, len(prediction_idx))[:, 0, :] @@ -68,7 +124,20 @@ def select_covariates(covs, prediction_idx, n_samples_in_batch): return covs -def prep_minibatch(covs, sample_batch_size): +def prep_minibatch(covs, sample_batch_size) -> tuple[int, int]: + """Prepare minibatch. + + Parameters + ---------- + covs + Covariates. + sample_batch_size + Sample batch size. + + Returns + ------- + Batch size and number of samples in the batch. + """ batch_size = covs.shape[0] if batch_size % sample_batch_size != 0: @@ -78,7 +147,34 @@ def prep_minibatch(covs, sample_batch_size): return batch_size, n_samples_in_batch -def get_predictions(prediction_idx, pred_values, true_values, size, bag_pred, bag_true, full_pred, offset=0): +def get_predictions( + prediction_idx, pred_values, true_values, size, bag_pred, bag_true, full_pred, offset=0 +) -> tuple[dict, dict, dict]: + """Get predictions. + + Parameters + ---------- + prediction_idx + Index of predictions. + pred_values + Predicted values. + true_values + True values. + size + Size of the bag minibatch. + bag_pred + Bag predictions. + bag_true + Bag true values. + full_pred + Full predictions, i.e. on cell-level. + offset + Offset, needed because of several possible types of predictions. + + Returns + ------- + Bag predictions, bag true values, full predictions on cell-level. + """ for i in range(len(prediction_idx)): bag_pred[i] = bag_pred.get(i, []) + [pred_values[offset + i].cpu()] bag_true[i] = bag_true.get(i, []) + [true_values[:, i].cpu()] @@ -90,6 +186,27 @@ def get_predictions(prediction_idx, pred_values, true_values, size, bag_pred, ba def get_bag_info(bags, n_samples_in_batch, minibatch_size, cell_counter, bag_counter, sample_batch_size): + """Get bag information. + + Parameters + ---------- + bags + Bags. + n_samples_in_batch + Number of samples in the batch. + minibatch_size + Minibatch size. + cell_counter + Cell counter. + bag_counter + Bag counter. + sample_batch_size + Sample batch size. + + Returns + ------- + Updated bags, cell counter, and bag counter. + """ if n_samples_in_batch == 1: bags += [[bag_counter] * minibatch_size] cell_counter += minibatch_size @@ -104,6 +221,31 @@ def get_bag_info(bags, n_samples_in_batch, minibatch_size, cell_counter, bag_cou def save_predictions_in_adata( adata, idx, predictions, bag_pred, bag_true, cell_pred, class_names, name, clip, reg=False ): + """Save predictions in anndata object. + + Parameters + ---------- + adata + Annotated data object. + idx + Index, i.e. obs_names. + predictions + Predictions. + bag_pred + Bag predictions. + bag_true + Bag true values. + cell_pred + Cell predictions. + class_names + Class names. + name + Name of the prediction column. + clip + Whether to transofrm the predictions. One of `clip`, `argmax`, or `none`. + reg + Whether the rediciton task is a regression task. + """ # cell level predictions df = create_df(cell_pred[idx], class_names, index=adata.obs_names) adata.obsm[f"full_predictions_{name}"] = df @@ -131,6 +273,17 @@ def save_predictions_in_adata( def plt_plot_losses(history, loss_names, save): + """Plot losses. + + Parameters + ---------- + history + History of losses. + loss_names + Loss names to plot. + save + Path to save the plot. + """ df = pd.concat(history, axis=1) df.columns = df.columns.droplevel(-1) df["epoch"] = df.index From df8b660fc01205badc873a26c66fcf42178290a2 Mon Sep 17 00:00:00 2001 From: alitinet Date: Sun, 21 Jul 2024 19:23:42 +0200 Subject: [PATCH 2/2] removed fail on warnings flag --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 69897c3..23a5340 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,7 +7,7 @@ build: sphinx: configuration: docs/conf.py # disable this for more lenient docs builds - fail_on_warning: true + fail_on_warning: false python: install: - method: pip