Skip to content

Commit

Permalink
Merge pull request #69 from theislab/fix_docs
Browse files Browse the repository at this point in the history
Fix docs
  • Loading branch information
alitinet authored Jul 21, 2024
2 parents aab0342 + df8b660 commit 8830bf2
Show file tree
Hide file tree
Showing 13 changed files with 1,048 additions and 413 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build:
sphinx:
configuration: docs/conf.py
# disable this for more lenient docs builds
fail_on_warning: true
fail_on_warning: false
python:
install:
- method: pip
Expand Down
13 changes: 9 additions & 4 deletions src/multimil/data/_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@
def organize_multiome_anndatas(
adatas: list[list[ad.AnnData | None]],
layers: list[list[str | None]] | None = None,
):
) -> ad.AnnData:
"""Concatenate all the input anndata objects.
These anndata objects should already have been preprocessed so that all single-modality
objects use a subset of the features used in the multiome object. The feature names (index of
`.var`) should match between the objects for vertical integration and cell names (index of
`.obs`) should match between the objects for horizontal integration.
:param adatas:
List of Lists with AnnData objects or None where each sublist corresponds to a modality
:param layers:
Parameters
----------
adatas
List of Lists with AnnData objects or None where each sublist corresponds to a modality.
layers
List of Lists of the same lengths as `adatas` specifying which `.layer` to use for each AnnData. Default is None which means using `.X`.
Returns
-------
Concatenated AnnData object across modalities and datasets.
"""
# TODO: add checks for layers
# TODO: add check that len of modalities is the same as len of losses, etc
Expand Down
47 changes: 25 additions & 22 deletions src/multimil/dataloaders/_ann_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
class StratifiedSampler(torch.utils.data.sampler.Sampler):
"""Custom stratified sampler class which enables sampling the same number of observation from each group in each mini-batch.
:param indices:
list of indices to sample from
:param batch_size:
batch size of each iteration
:param shuffle:
if ``True``, shuffles indices before sampling
:param drop_last:
if int, drops the last batch if its length is less than drop_last.
if drop_last == True, drops last non-full batch.
if drop_last == False, iterate over all batches.
Parameters
----------
indices
List of indices to sample from.
batch_size
Batch size of each iteration.
shuffle
If ``True``, shuffles indices before sampling.
drop_last
If int, drops the last batch if its length is less than drop_last.
If drop_last == True, drops last non-full batch.
If drop_last == False, iterate over all batches.
"""

def __init__(
Expand Down Expand Up @@ -139,23 +141,24 @@ def __len__(self):
# https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_ann_dataloader.py#L89
# accessed on 5 November 2022
class GroupAnnDataLoader(DataLoader):
"""
DataLoader for loading tensors from AnnData objects.
"""DataLoader for loading tensors from AnnData objects.
:param adata_manager:
Parameters
----------
adata_manager
:class:`~scvi.data.AnnDataManager` object with a registered AnnData object.
:param shuffle:
Whether the data should be shuffled
:param indices:
The indices of the observations in the adata to load
:param batch_size:
minibatch size to load each iteration
:param data_and_attributes:
shuffle
Whether the data should be shuffled.
indices
The indices of the observations in the adata to load.
batch_size
Minibatch size to load each iteration.
data_and_attributes
Dictionary with keys representing keys in data registry (`adata.uns["_scvi"]`)
and value equal to desired numpy loading type (later made into torch tensor).
If `None`, defaults to all registered data.
:param data_loader_kwargs:
Keyword arguments for :class:`~torch.utils.data.DataLoader`
data_loader_kwargs
Keyword arguments for :class:`~torch.utils.data.DataLoader`.
"""

def __init__(
Expand Down
19 changes: 10 additions & 9 deletions src/multimil/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,22 @@
# https://github.com/scverse/scvi-tools/blob/0b802762869c43c9f49e69fe62b1a5a9b5c4dae6/scvi/dataloaders/_data_splitting.py#L56
# accessed on 5 November 2022
class GroupDataSplitter(DataSplitter):
"""
Creates data loaders ``train_set``, ``validation_set``, ``test_set``.
"""Creates data loaders ``train_set``, ``validation_set``, ``test_set``.
If ``train_size + validation_set < 1`` then ``test_set`` is non-empty.
:param adata_manager:
Parameters
----------
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
:param train_size:
float, or None (default is 0.9)
:param validation_size:
float, or None (default is None)
:param use_gpu:
train_size
Proportion of cells to use as the train set. Float, or None (default is 0.9).
validation_size
Proportion of cell to use as the valisation set. Float, or None (default is None). If None, is set to 1 - ``train_size``.
use_gpu
Use default GPU if available (if None or True), or index of GPU to use (if int),
or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
:param kwargs:
kwargs
Keyword args for data loader. Data loader class is :class:`~mtg.dataloaders.GroupAnnDataLoader`.
"""

Expand Down
40 changes: 26 additions & 14 deletions src/multimil/distributions/_mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
class MMD(torch.nn.Module):
"""Maximum mean discrepancy.
:param kernel_type:
Parameters
----------
kernel_type
Indicates if to use Gaussian kernel. One of
* ``'gaussian'`` - use Gaussian kernel
* ``'not gaussian'`` - do not use Gaussian kernel
* ``'not gaussian'`` - do not use Gaussian kernel.
"""

def __init__(self, kernel_type="gaussian"):
Expand All @@ -23,12 +25,18 @@ def gaussian_kernel(
) -> torch.Tensor:
"""Apply Guassian kernel.
:param x:
Tensor from the first distribution
:param y:
Tensor from the second distribution
:param gamma:
List of gamma parameters
Parameters
----------
x
Tensor from the first distribution.
y
Tensor from the second distribution.
gamma
List of gamma parameters.
Returns
-------
Gaussian kernel between ``x`` and ``y``.
"""
if gamma is None:
gamma = [
Expand Down Expand Up @@ -71,12 +79,16 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Availability: https://github.com/theislab/scarches/blob/63a7c2b35a01e55fe7e1dd871add459a86cd27fb/scarches/models/trvae/losses.py
Citation: Gretton, Arthur, et al. "A Kernel Two-Sample Test", 2012.
:param x:
Tensor with shape ``(batch_size, z_dim)``
:param y:
Tensor with shape ``(batch_size, z_dim)``
:returns:
MMD between ``x`` and ``y``
Parameters
----------
x
Tensor with shape ``(batch_size, z_dim)``.
y
Tensor with shape ``(batch_size, z_dim)``.
Returns
-------
MMD between ``x`` and ``y``.
"""
# in case there is only one sample in a batch belonging to one of the groups, then skip the batch
if len(x) == 1 or len(y) == 1:
Expand Down
Loading

0 comments on commit 8830bf2

Please sign in to comment.