refactor(breadbox): Service layer proposal #93

Draft · wants to merge 7 commits into master
21 changes: 13 additions & 8 deletions breadbox/breadbox/api/datasets.py
@@ -30,7 +30,6 @@
 from breadbox.crud.access_control import PUBLIC_GROUP_ID
 from ..crud import dataset as dataset_crud
 from ..crud import types as type_crud
-from ..crud import slice as slice_crud

 from ..models.dataset import (
     Dataset as DatasetModel,
@@ -54,6 +53,12 @@
     DimensionDataResponse,
     SliceQueryIdentifierType,
 )
+from breadbox.service.data_loading import get_subsetted_matrix_dataset_df
+from breadbox.service.labels import (
+    get_dataset_feature_labels_by_id,
+    get_dataset_sample_labels_by_id,
+)
+from breadbox.service.slice import get_slice_data, get_labels_for_slice_type
 from .dependencies import get_dataset as get_dataset_dep
 from .dependencies import get_db_with_user, get_user
@@ -110,7 +115,7 @@ def get_dataset_features(
     if dataset is None:
         raise HTTPException(404, "Dataset not found")

-    feature_labels_by_id = dataset_crud.get_dataset_feature_labels_by_id(
+    feature_labels_by_id = get_dataset_feature_labels_by_id(
         db=db, user=user, dataset=dataset,
     )
     return [{"id": id, "label": label} for id, label in feature_labels_by_id.items()]
@@ -133,7 +138,7 @@ def get_dataset_samples(
     if dataset is None:
         raise HTTPException(404, "Dataset not found")

-    sample_labels_by_id = dataset_crud.get_dataset_sample_labels_by_id(
+    sample_labels_by_id = get_dataset_sample_labels_by_id(
         db=db, user=user, dataset=dataset,
     )
     return [{"id": id, "label": label} for id, label in sample_labels_by_id.items()]
@@ -185,7 +190,7 @@ def get_feature_data(
     # Get the feature label
     if dataset.feature_type_name:
         # Note: this would be faster if we had a query to load one label instead of all labels - but performance hasn't been an issue
-        feature_labels_by_id = dataset_crud.get_dataset_feature_labels_by_id(
+        feature_labels_by_id = get_dataset_feature_labels_by_id(
             db=db, user=user, dataset=dataset
         )
         label = feature_labels_by_id[feature.given_id]
@@ -323,7 +328,7 @@ def get_matrix_dataset_data(
     ] = False,
 ):
     try:
-        df = dataset_crud.get_subsetted_matrix_dataset_df(
+        df = get_subsetted_matrix_dataset_df(
             db,
             user,
             dataset,
@@ -404,7 +409,7 @@ def get_dataset_data(
     except UserError as e:
         raise e

-    df = dataset_crud.get_subsetted_matrix_dataset_df(
+    df = get_subsetted_matrix_dataset_df(
         db, user, dataset, dim_info, settings.filestore_location
     )

@@ -490,13 +495,13 @@ def get_dimension_data(
         identifier=identifier,
         identifier_type=identifier_type.name,
     )
-    slice_values_by_id = slice_crud.get_slice_data(
+    slice_values_by_id = get_slice_data(
         db, settings.filestore_location, parsed_slice_query
     )

     # Load the labels separately, ensuring they're in the same order as the other values
     slice_ids: list = slice_values_by_id.index.tolist()
-    labels_by_id = slice_crud.get_labels_for_slice_type(db, parsed_slice_query)
+    labels_by_id = get_labels_for_slice_type(db, parsed_slice_query)
     slice_labels = [labels_by_id[id] for id in slice_ids] if labels_by_id else slice_ids

     return {
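The call-site edits in this file are mechanical: each helper keeps its signature, and only its import moves from the crud modules to the new service layer. A minimal sketch of the before/after pattern (the handler below is illustrative, not an actual breadbox route; only the import path and signature come from this diff):

    # Before this PR: label lookups went through the CRUD module.
    #     from ..crud import dataset as dataset_crud
    #     labels = dataset_crud.get_dataset_feature_labels_by_id(db=db, user=user, dataset=dataset)

    # After this PR: same helper, same arguments, imported from the service layer.
    from breadbox.service.labels import get_dataset_feature_labels_by_id

    def list_feature_labels(db, user, dataset):
        # Returns {given_id: label}, falling back to given_ids when no metadata labels exist.
        labels_by_id = get_dataset_feature_labels_by_id(db=db, user=user, dataset=dataset)
        return [{"id": id, "label": label} for id, label in labels_by_id.items()]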
13 changes: 7 additions & 6 deletions breadbox/breadbox/compute/analysis_tasks.py
@@ -23,8 +23,11 @@
     ValueType,
 )
 from breadbox.schemas.custom_http_exception import ResourceNotFoundError, UserError
-
 from breadbox.schemas.dataset import MatrixDatasetIn
+from breadbox.service.labels import (
+    get_dataset_feature_labels_by_id,
+    get_dataset_feature_by_label,
+)

 from ..crud.types import get_dimension_type
 from ..crud import dataset as dataset_crud
@@ -170,9 +173,7 @@ def get_features_info_and_dataset(
     result_features: List[Feature] = []
     dataset_feature_ids: List[str] = []
     datasets: List[Dataset] = []
-    feature_labels_by_id = dataset_crud.get_dataset_feature_labels_by_id(
-        db, user, dataset
-    )
+    feature_labels_by_id = get_dataset_feature_labels_by_id(db, user, dataset)
     feature_indices = []

     for dataset_feat in dataset_features:
@@ -493,12 +494,12 @@ def create_cell_line_group(

     # Return the feature ID associated with the new dataset feature
     if use_feature_ids:
-        feature: DatasetFeature = dataset_crud.get_dataset_feature_by_label(
+        feature: DatasetFeature = get_dataset_feature_by_label(
             db=db, dataset_id=dataset_id, feature_label=feature_label
         )
         return _format_breadbox_shim_slice_id(feature.dataset_id, feature.given_id)
     else:
-        dataset_feature = dataset_crud.get_dataset_feature_by_label(
+        dataset_feature = get_dataset_feature_by_label(
             db, dataset_id, feature_label
         )
         return str(dataset_feature.id)
2 changes: 1 addition & 1 deletion breadbox/breadbox/compute/download_tasks.py
@@ -11,7 +11,6 @@
 from breadbox.crud.dataset import get_sample_indexes_by_given_ids
 from breadbox.crud.dataset import get_all_sample_indexes
 from breadbox.crud.partial import get_cell_line_selector_lines
-from breadbox.crud.dataset import get_dataset_feature_labels_by_id
 from ..config import get_settings
 from ..models.dataset import (
     Dataset,
@@ -21,6 +20,7 @@
 )
 from .celery import app
 from ..db.util import db_context
+from breadbox.service.labels import get_dataset_feature_labels_by_id


 def _progress_callback(task, percentage, message="Fetching data"):
188 changes: 10 additions & 178 deletions breadbox/breadbox/crud/dataset.py
@@ -53,10 +53,8 @@
     TRANSIENT_GROUP_ID,
     get_transient_group,
 )
-from breadbox.io.filestore_crud import (
-    get_slice,
-    delete_data_files,
-)
+from breadbox.io.filestore_crud import delete_data_files
+from breadbox.crud.dimension_type import get_dimension_type
 from .metadata import cast_tabular_cell_value_type
 from .dataset_reference import add_id_mapping
 import typing
@@ -994,47 +992,12 @@ def get_dataset_samples(db: SessionWithUser, dataset: Dataset, user: str):
     return dataset_samples


-def get_dataset_feature_labels_by_id(
-    db: SessionWithUser, user: str, dataset: Dataset,
-) -> dict[str, str]:
-    """
-    Try loading feature labels from metadata.
-    If there are no labels in the metadata or there is no metadata, then just return the feature names.
-    """
-    metadata_labels_by_given_id = get_dataset_feature_annotations(
-        db=db, user=user, dataset=dataset, metadata_col_name="label"
-    )
-
-    if metadata_labels_by_given_id:
-        return metadata_labels_by_given_id
-    else:
-        all_dataset_features = get_dataset_features(db=db, dataset=dataset, user=user)
-        return {feature.given_id: feature.given_id for feature in all_dataset_features}
-
-
-def get_dataset_sample_labels_by_id(
-    db: SessionWithUser, user: str, dataset: Dataset,
-) -> dict[str, str]:
-    """
-    Try loading sample labels from metadata.
-    If there are no labels in the metadata or there is no metadata, then just return the sample names.
-    """
-    metadata_labels = get_dataset_sample_annotations(
-        db=db, user=user, dataset=dataset, metadata_col_name="label"
-    )
-    if metadata_labels:
-        return metadata_labels
-    else:
-        samples = get_dataset_samples(db=db, dataset=dataset, user=user)
-        return {sample.given_id: sample.given_id for sample in samples}
-
-
 from typing import Any


 # TODO: This can probably be merged.
 def get_dataset_feature_annotations(
-    db: SessionWithUser, user: str, dataset: Dataset, metadata_col_name: str,
+    db: SessionWithUser, user: str, dataset: MatrixDataset, metadata_col_name: str,
 ) -> dict[str, Any]:
     """
     For the given dataset, load metadata of the specified type, keyed by feature id.
@@ -1047,16 +1010,9 @@

     # Try to find the associated metadata dataset
     feature_metadata_dataset_id = None
-    if dataset.format == "matrix_dataset":
-        if dataset.feature_type is not None:
-            feature_type = (
-                db.query(DimensionType)
-                .filter(DimensionType.name == dataset.feature_type_name)
-                .one()
-            )
-            feature_metadata_dataset_id = feature_type.dataset_id
-        else:
-            feature_metadata_dataset_id = dataset.id
+    if dataset.feature_type is not None:
+        feature_type = get_dimension_type(db, dataset.feature_type_name)
+        feature_metadata_dataset_id = feature_type.dataset_id

     data_dataset_feature = aliased(DatasetFeature)

@@ -1082,7 +1038,7 @@


 def get_dataset_sample_annotations(
-    db: SessionWithUser, user: str, dataset: Dataset, metadata_col_name: str
+    db: SessionWithUser, user: str, dataset: MatrixDataset, metadata_col_name: str
 ) -> dict[str, Any]:
     """
     For the given dataset, load metadata of the specified type, keyed by sample id.
@@ -1095,16 +1051,9 @@

     # Try to find the associated metadata dataset
     sample_metadata_dataset_id = None
-    if dataset.format == "matrix_dataset":
-        if dataset.sample_type is not None:
-            sample_type = (
-                db.query(DimensionType)
-                .filter(DimensionType.name == dataset.sample_type_name)
-                .one()
-            )
-            sample_metadata_dataset_id = sample_type.dataset_id
-        else:
-            sample_metadata_dataset_id = dataset.id
+    if dataset.sample_type is not None:
+        sample_type = get_dimension_type(db, dataset.sample_type_name)
+        sample_metadata_dataset_id = sample_type.dataset_id

     data_dataset_sample = aliased(DatasetSample)

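Both annotation helpers above now delegate the DimensionType lookup to the shared get_dimension_type helper instead of building the query inline. A rough sketch of what that helper is assumed to do, inferred from the inline query it replaces (the real implementation in breadbox/crud/dimension_type.py is not part of this diff):

    # Assumed shape of the shared lookup; mirrors the deleted inline query above.
    def get_dimension_type(db: SessionWithUser, name: str) -> DimensionType:
        # .one() raises if no dimension type matches, as the old inline query did.
        return db.query(DimensionType).filter(DimensionType.name == name).one()

Narrowing the dataset parameter to MatrixDataset is what allows dropping the old dataset.format == "matrix_dataset" guard and its tabular fallback (which used the dataset's own id as the metadata dataset).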
Expand Down Expand Up @@ -1426,50 +1375,6 @@ def get_dataset_sample_by_given_id(
return sample


def get_dataset_feature_by_label(
db: SessionWithUser, dataset_id: str, feature_label: str
) -> DatasetFeature:
"""Load the dataset feature corresponding to the given dataset ID and feature label"""

dataset = get_dataset(db, db.user, dataset_id)
if dataset is None:
raise ResourceNotFoundError(f"Dataset '{dataset_id}' not found.")
assert_user_has_access_to_dataset(dataset, db.user)
assert isinstance(dataset, MatrixDataset)

labels_by_given_id = get_dataset_feature_labels_by_id(db, db.user, dataset)
given_ids_by_label = {label: id for id, label in labels_by_given_id.items()}
feature_given_id = given_ids_by_label.get(feature_label)
if feature_given_id is None:
raise ResourceNotFoundError(
f"Feature label '{feature_label}' not found in dataset '{dataset_id}'."
)

return get_dataset_feature_by_given_id(db, dataset_id, feature_given_id)


def get_dataset_sample_by_label(
db: SessionWithUser, dataset_id: str, sample_label: str
) -> DatasetSample:
"""Load the dataset sample corresponding to the given dataset ID and sample label"""

dataset = get_dataset(db, db.user, dataset_id)
if dataset is None:
raise ResourceNotFoundError(f"Dataset '{dataset_id}' not found.")
assert_user_has_access_to_dataset(dataset, db.user)
assert isinstance(dataset, MatrixDataset)

labels_by_given_id = get_dataset_sample_labels_by_id(db, db.user, dataset)
given_ids_by_label = {label: id for id, label in labels_by_given_id.items()}
sample_given_id = given_ids_by_label.get(sample_label)
if sample_given_id is None:
raise ResourceNotFoundError(
f"Sample label '{sample_label}' not found in dataset '{dataset_id}'."
)

return get_dataset_sample_by_given_id(db, dataset_id, sample_given_id)


def _get_column_types(columns_metadata, columns: Optional[List[str]]):
col_and_column_metadata_pairs = columns_metadata.items()
if columns is None:
@@ -1660,76 +1565,3 @@ def get_missing_tabular_columns_and_indices(
     )

     return missing_columns, missing_indices
-
-
-def get_subsetted_matrix_dataset_df(
-    db: SessionWithUser,
-    user: str,
-    dataset: Dataset,
-    dimensions_info: MatrixDimensionsInfo,
-    filestore_location,
-    strict: bool = False,  # False default for backwards compatibility
-):
-    """
-    Load a dataframe containing data for the specified dimensions.
-    If the dimensions are specified by label, then return a result indexed by labels
-    """
-
-    missing_features = []
-    missing_samples = []
-
-    if dimensions_info.features is None:
-        feature_indexes = None
-    elif dimensions_info.feature_identifier.value == "id":
-        feature_indexes, missing_features = get_feature_indexes_by_given_ids(
-            db, user, dataset, dimensions_info.features
-        )
-    else:
-        assert dimensions_info.feature_identifier.value == "label"
-        feature_indexes, missing_features = get_dimension_indexes_of_labels(
-            db, user, dataset, axis="feature", dimension_labels=dimensions_info.features
-        )
-
-    if len(missing_features) > 0:
-        log.warning(f"Could not find features: {missing_features}")
-
-    if dimensions_info.samples is None:
-        sample_indexes = None
-    elif dimensions_info.sample_identifier.value == "id":
-        sample_indexes, missing_samples = get_sample_indexes_by_given_ids(
-            db, user, dataset, dimensions_info.samples
-        )
-    else:
-        sample_indexes, missing_samples = get_dimension_indexes_of_labels(
-            db, user, dataset, axis="sample", dimension_labels=dimensions_info.samples
-        )
-
-    if len(missing_samples) > 0:
-        log.warning(f"Could not find samples: {missing_samples}")
-
-    if strict:
-        num_missing_features = len(missing_features)
-        missing_features_msg = f"{num_missing_features} missing features: {missing_features[:20] + ['...'] if num_missing_features >= 20 else missing_features}"
-        num_missing_samples = len(missing_samples)
-        missing_samples_msg = f"{num_missing_samples} missing samples: {missing_samples[:20] + ['...'] if num_missing_samples >= 20 else missing_samples}"
-        if len(missing_features) > 0 or len(missing_samples) > 0:
-            raise UserError(f"{missing_features_msg} and {missing_samples_msg}")
-
-    # call sort on the indices because hdf5_read requires indices be in ascending order
-    if feature_indexes is not None:
-        feature_indexes = sorted(feature_indexes)
-    if sample_indexes is not None:
-        sample_indexes = sorted(sample_indexes)
-
-    df = get_slice(dataset, feature_indexes, sample_indexes, filestore_location)
-
-    # Re-index by label if applicable
-    if dimensions_info.feature_identifier == FeatureSampleIdentifier.label:
-        labels_by_id = get_dataset_feature_labels_by_id(db, user, dataset)
-        df = df.rename(columns=labels_by_id)
-
-    if dimensions_info.sample_identifier == FeatureSampleIdentifier.label:
-        label_by_id = get_dataset_sample_labels_by_id(db, user, dataset)
-        df = df.rename(index=label_by_id)
-
-    return df
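Taken together, the commits move the higher-level helpers out of breadbox.crud into a service layer. The summary below is assembled from the imports and deletions in this diff; the module paths and signatures appear in the PR, but anything else about the modules' contents is an assumption:

    # breadbox/service/labels.py
    #     get_dataset_feature_labels_by_id(db, user, dataset) -> dict[str, str]
    #     get_dataset_sample_labels_by_id(db, user, dataset) -> dict[str, str]
    #     get_dataset_feature_by_label(db, dataset_id, feature_label) -> DatasetFeature
    #     (get_dataset_sample_by_label is deleted here too; no import in this diff
    #     shows its new home, but service/labels.py would be the likely spot)
    #
    # breadbox/service/data_loading.py
    #     get_subsetted_matrix_dataset_df(db, user, dataset, dimensions_info,
    #                                     filestore_location, strict=False)
    #
    # breadbox/service/slice.py
    #     get_slice_data(db, filestore_location, slice_query)
    #     get_labels_for_slice_type(db, slice_query)

crud/dataset.py keeps the lower-level queries (get_dataset_feature_annotations, get_feature_indexes_by_given_ids, and friends), so for these operations the dependency direction becomes service -> crud/io rather than api -> crud. Consistent with that, get_slice is no longer imported in crud/dataset.py, so the relocated get_subsetted_matrix_dataset_df presumably imports it from breadbox.io.filestore_crud in its new module.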