filter: Split filtering and subsampling #1432

Draft · wants to merge 5 commits into master

6 changes: 6 additions & 0 deletions DEPRECATED.md
@@ -4,6 +4,12 @@ These features are deprecated, which means they are no longer maintained and
will go away in a future major version of Augur. They are currently still
available for backwards compatibility, but should not be used in new code.

## `augur filter --metadata-chunk-size`

*Deprecated in version X.X.X. Planned for removal in September 2024 or after.*

FIXME: add description here

## `augur parse` preference of `name` over `strain` as the sequence ID field

*Deprecated in February 2024. Planned to be reordered June 2024 or after.*
10 changes: 7 additions & 3 deletions augur/filter/__init__.py
@@ -2,7 +2,7 @@
Filter and subsample a sequence set.
"""
from augur.dates import numeric_date_type, SUPPORTED_DATE_HELP_TEXT
from augur.filter.io import ACCEPTED_TYPES, column_type_pair
from augur.filter.io import ACCEPTED_TYPES, cleanup_outputs, column_type_pair
from augur.io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN
from augur.types import EmptyOutputReportingMethod
from . import constants
@@ -18,7 +18,7 @@ def register_arguments(parser):
input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata")
input_group.add_argument('--sequences', '-s', help="sequences in FASTA or VCF format")
input_group.add_argument('--sequence-index', help="sequence composition report generated by augur index. If not provided, an index will be created on the fly.")
input_group.add_argument('--metadata-chunk-size', type=int, default=100000, help="maximum number of metadata records to read into memory at a time. Increasing this number can speed up filtering at the cost of more memory used.")
input_group.add_argument('--metadata-chunk-size', help="[DEPRECATED] Previously used to specify maximum number of metadata records to read into memory at a time. This no longer has an effect.")
input_group.add_argument('--metadata-id-columns', default=DEFAULT_ID_COLUMNS, nargs="+", help="names of possible metadata columns containing identifier information, ordered by priority. Only one ID column will be inferred.")
input_group.add_argument('--metadata-delimiters', default=DEFAULT_DELIMITERS, nargs="+", help="delimiters to accept when reading a metadata file. Only one delimiter will be inferred.")

@@ -104,4 +104,8 @@ def run(args):
validate_arguments(args)

from ._run import run as _run
return _run(args)
try:
return _run(args)
except:
cleanup_outputs(args)
raise
344 changes: 101 additions & 243 deletions augur/filter/_run.py
Member Author commented:

52fbad6 passes all Cram tests but fails on the ncov trial run:

  File "/nextstrain/augur/augur/io/metadata.py", line 155, in read_metadata
    return pd.read_csv(
  …
OSError: [Errno 12] Cannot allocate memory

This means that, given its current and growing size, the SARS-CoV-2 dataset is impractical to load into memory all at once, even with optimizations such as loading only a subset of columns and 8d7206c. The only way to keep the functions in this codebase modular and portable is an on-disk approach, which Pandas does not support. There is a spectrum of alternatives, which I think can be divided into two categories:

  1. A Pandas-like alternative such as Dask. It is unclear how portable the existing Pandas logic is to Dask, but ideally this would be close to a library swap with minimal code changes.
  2. A database-file approach such as SQLite, started in filter: Rewrite using SQLite3 #1242. It is essentially a rewrite of augur filter and requires extensive testing. The other downside is that it will require some form of Pandas round-tripping to continue supporting the widely used Pandas-based queries in --query.

I think it's reasonable to explore (1) on top of the current changes in this PR to see whether it's a viable option. Even if the end goal is (2), (1) would provide a good stepping stone toward the database approach; a rough sketch of the kind of swap I have in mind is below.
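
A minimal sketch of option (1), assuming Dask's dataframe module as a near drop-in for the Pandas read in read_metadata. The function name, column names, and groupby usage here are illustrative assumptions, not the actual implementation:

```python
# Hypothetical sketch: swap pandas.read_csv for dask.dataframe.read_csv so metadata
# stays partitioned on disk instead of being loaded into memory all at once.
import dask.dataframe as dd

def read_metadata_lazily(path, delimiter="\t", id_column="strain"):
    # dd.read_csv returns a lazy, partitioned dataframe; nothing is pulled into
    # memory until .compute() (or another eager operation) is called.
    metadata = dd.read_csv(
        path,
        sep=delimiter,
        dtype="string",    # avoid reading numeric IDs or year-only dates as integers
        blocksize="64MB",  # partition size; tune for available memory
    )
    return metadata.set_index(id_column)

# Grouping and per-group sampling would read much like the Pandas version,
# but remain lazy until explicitly computed, e.g.:
#   grouped = metadata.groupby(["region", "year"])
#   sampled = grouped.apply(lambda g: g.nlargest(5, "priority")).compute()
```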

Large diffs are not rendered by default.

239 changes: 95 additions & 144 deletions augur/filter/subsample.py
@@ -1,34 +1,88 @@
import heapq
import itertools
import uuid
import numpy as np
import pandas as pd
from collections import defaultdict
from typing import Collection

from augur.dates import get_iso_year_week
from augur.errors import AugurError
from augur.filter.io import read_priority_scores
from augur.io.metadata import METADATA_DATE_COLUMN
from augur.io.print import print_err
from . import constants


def get_groups_for_subsampling(strains, metadata, group_by=None):
"""Return a list of groups for each given strain based on the corresponding
metadata and group by column.
def subsample(metadata, args, group_by):

# Use user-defined priorities, if possible. Otherwise, setup a
# corresponding dictionary that returns a random float for each strain.
if args.priority:
priorities = read_priority_scores(args.priority)
else:
random_generator = np.random.default_rng(args.subsample_seed)
priorities = defaultdict(random_generator.random)

# Generate columns for grouping.
grouping_metadata = enrich_metadata(
metadata,
group_by,
)

# Enrich with priorities.
grouping_metadata['priority'] = [priorities[strain] for strain in grouping_metadata.index]

pandas_groupby = grouping_metadata.groupby(list(group_by), group_keys=False)

n_groups = len(pandas_groupby.groups)

# Determine sequences per group.
if args.sequences_per_group:
sequences_per_group = args.sequences_per_group
elif args.subsample_max_sequences:
group_sizes = [len(strains) for strains in pandas_groupby.groups.values()]

try:
# Calculate sequences per group. If there are more groups than maximum
# sequences requested, sequences per group will be a floating point
# value and subsampling will be probabilistic.
sequences_per_group, probabilistic_used = calculate_sequences_per_group(
args.subsample_max_sequences,
group_sizes,
allow_probabilistic=args.probabilistic_sampling
)
except TooManyGroupsError as error:
raise AugurError(str(error)) from error

if (probabilistic_used):
print(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
else:
print(f"Sampling at {sequences_per_group} per group.")
else:
pass
# FIXME: what to do when no subsampling is requested?

group_size_limits = (size for size in get_group_size_limits(n_groups, sequences_per_group, random_seed=args.subsample_seed))

def row_sampler(group):
n = next(group_size_limits)
return group.nlargest(n, 'priority')

return {strain for strain in pandas_groupby.apply(row_sampler).index}

def enrich_metadata(metadata, group_by=None):
"""Enrich metadata with generated columns.

Parameters
----------
strains : list
A list of strains to get groups for.
metadata : pandas.DataFrame
Metadata to inspect for the given strains.
group_by : list
A list of metadata (or generated) columns to group records by.

Returns
-------
dict :
A mapping of strain names to tuples corresponding to the values of the strain's group.
metadata : pandas.DataFrame
Metadata with generated columns.

Examples
--------
@@ -75,15 +129,12 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
>>> get_groups_for_subsampling(strains, metadata, group_by=('_dummy',))
{'strain1': ('_dummy',), 'strain2': ('_dummy',)}
"""
metadata = metadata.loc[list(strains)]
group_by_strain = {}

if len(metadata) == 0:
return group_by_strain
return metadata

if not group_by or group_by == ('_dummy',):
group_by_strain = {strain: ('_dummy',) for strain in strains}
return group_by_strain
metadata['_dummy'] = '_dummy'
return metadata

group_by_set = set(group_by)
generated_columns_requested = constants.GROUP_BY_GENERATED_COLUMNS & group_by_set
@@ -130,9 +181,10 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
# Drop the date column since it should not be used for grouping.
metadata = pd.concat([metadata.drop(METADATA_DATE_COLUMN, axis=1), df_dates], axis=1)

# FIXME: I think this is useless - drop it in another commit
# Check again if metadata is empty after dropping ambiguous dates.
if metadata.empty:
return group_by_strain
# if metadata.empty:
# return group_by_strain

# Generate columns.
if constants.DATE_YEAR_COLUMN in generated_columns_requested:
@@ -164,149 +216,48 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
for group in unknown_groups:
metadata[group] = 'unknown'

# Finally, determine groups.
group_by_strain = dict(zip(metadata.index, metadata[group_by].apply(tuple, axis=1)))
return group_by_strain


class PriorityQueue:
"""A priority queue implementation that automatically replaces lower priority
items in the heap with incoming higher priority items.

Examples
--------

Add a single record to a heap with a maximum of 2 records.

>>> queue = PriorityQueue(max_size=2)
>>> queue.add({"strain": "strain1"}, 0.5)
1

Add another record with a higher priority. The queue should be at its maximum
size.

>>> queue.add({"strain": "strain2"}, 1.0)
2
>>> queue.heap
[(0.5, 0, {'strain': 'strain1'}), (1.0, 1, {'strain': 'strain2'})]
>>> list(queue.get_items())
[{'strain': 'strain1'}, {'strain': 'strain2'}]
return metadata

Add a higher priority record that causes the queue to exceed its maximum
size. The resulting queue should contain the two highest priority records
after the lowest priority record is removed.

>>> queue.add({"strain": "strain3"}, 2.0)
2
>>> list(queue.get_items())
[{'strain': 'strain2'}, {'strain': 'strain3'}]

Add a record with the same priority as another record, forcing the duplicate
to be resolved by removing the oldest entry.

>>> queue.add({"strain": "strain4"}, 1.0)
2
>>> list(queue.get_items())
[{'strain': 'strain4'}, {'strain': 'strain3'}]

"""
def __init__(self, max_size):
"""Create a fixed size heap (priority queue)

"""
self.max_size = max_size
self.heap = []
self.counter = itertools.count()

def add(self, item, priority):
"""Add an item to the queue with a given priority.

If adding the item causes the queue to exceed its maximum size, replace
the lowest priority item with the given item. The queue stores items
with an additional heap id value (a count) to resolve ties between items
with equal priority (favoring the most recently added item).

"""
heap_id = next(self.counter)

if len(self.heap) >= self.max_size:
heapq.heappushpop(self.heap, (priority, heap_id, item))
else:
heapq.heappush(self.heap, (priority, heap_id, item))

return len(self.heap)

def get_items(self):
"""Return each item in the queue in order.

Yields
------
Any
Item stored in the queue.

"""
for priority, heap_id, item in self.heap:
yield item


def create_queues_by_group(groups, max_size, max_attempts=100, random_seed=None):
"""Create a dictionary of priority queues per group for the given maximum size.
def get_group_size_limits(number_of_groups: int, max_size, max_attempts = 100, random_seed = None):
"""Return a list of group size limits.

When the maximum size is fractional, probabilistically sample the maximum
size from a Poisson distribution. Make at least the given number of maximum
attempts to create queues for which the sum of their maximum sizes is
attempts to create groups for which the sum of their maximum sizes is
greater than zero.

Examples
--------

Create queues for two groups with a fixed maximum size.

>>> groups = ("2015", "2016")
>>> queues = create_queues_by_group(groups, 2)
>>> sum(queue.max_size for queue in queues.values())
4

Create queues for two groups with a fractional maximum size. Their total max
size should still be an integer value greater than zero.

>>> seed = 314159
>>> queues = create_queues_by_group(groups, 0.1, random_seed=seed)
>>> int(sum(queue.max_size for queue in queues.values())) > 0
True

A subsequent run of this function with the same groups and random seed
should produce the same queues and queue sizes.

>>> more_queues = create_queues_by_group(groups, 0.1, random_seed=seed)
>>> [queue.max_size for queue in queues.values()] == [queue.max_size for queue in more_queues.values()]
True

Parameters
----------
number_of_groups : int
The number of groups.
max_size : int | float
Maximum size of a group.
max_attempts : int
Maximum number of attempts for creating group sizes.
random_seed
Seed value for np.random.default_rng for reproducible randomness.
"""
queues_by_group = {}
sizes = None
total_max_size = 0
attempts = 0

if max_size < 1.0:
random_generator = np.random.default_rng(random_seed)
# If max_size is not fractional, use it as the limit for all groups.
if int(max_size) == max_size:
return np.full(number_of_groups, max_size)

# For small fractional maximum sizes, it is possible to randomly select
# maximum queue sizes that all equal zero. When this happens, filtering
# fails unexpectedly. We make multiple attempts to create queues with
# maximum sizes greater than zero for at least one queue.
# maximum sizes that all equal zero. When this happens, filtering
# fails unexpectedly. We make multiple attempts to create sizes with
# maximum sizes greater than zero for at least one group.
while total_max_size == 0 and attempts < max_attempts:
for group in sorted(groups):
if max_size < 1.0:
queue_max_size = random_generator.poisson(max_size)
else:
queue_max_size = max_size

queues_by_group[group] = PriorityQueue(queue_max_size)

total_max_size = sum(queue.max_size for queue in queues_by_group.values())
sizes = np.random.default_rng(random_seed).poisson(max_size, size=number_of_groups)
total_max_size = sum(sizes)
attempts += 1

return queues_by_group
assert sizes is not None

return sizes


def calculate_sequences_per_group(target_max_value, group_sizes, allow_probabilistic=True):
4 changes: 4 additions & 0 deletions augur/filter/validate_arguments.py
@@ -1,4 +1,5 @@
from augur.errors import AugurError
from augur.io.print import print_err
from augur.io.vcf import is_vcf as filename_is_vcf


@@ -16,6 +17,9 @@ def validate_arguments(args):
args : argparse.Namespace
Parsed arguments from argparse
"""
if args.metadata_chunk_size:
print_err("WARNING: --metadata-chunk-size is no longer necessary and will be removed in a future version.")

# Don't allow sequence output when no sequence input is provided.
if args.output and not args.sequences:
raise AugurError("You need to provide sequences to output sequences.")
2 changes: 1 addition & 1 deletion augur/io/metadata.py
@@ -134,7 +134,7 @@ def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, columns=None, id

if isinstance(dtype, dict):
# Avoid reading numerical IDs as integers.
dtype["index_col"] = "string"
dtype[index_col] = "string"

# Avoid reading year-only dates as integers.
dtype[METADATA_DATE_COLUMN] = "string"