From cf0b808ac78573338c9854663fc81609d37b8fbb Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:52:49 -0800 Subject: [PATCH 1/5] Cleanup outputs using top-level exception handling This ensures that no output files are written on error. --- augur/filter/__init__.py | 8 ++++++-- augur/filter/_run.py | 3 +-- .../cram/filter-empty-output-reporting.t | 3 ++- .../cram/filter-mismatched-sequences-error.t | 12 ++++++++---- .../filter/cram/filter-sequences-vcf.t | 9 ++++++--- .../cram/subsample-ambiguous-dates-error.t | 18 ++++++++++++------ 6 files changed, 35 insertions(+), 18 deletions(-) diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index 3e647f2c3..190465dbe 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -2,7 +2,7 @@ Filter and subsample a sequence set. """ from augur.dates import numeric_date_type, SUPPORTED_DATE_HELP_TEXT -from augur.filter.io import ACCEPTED_TYPES, column_type_pair +from augur.filter.io import ACCEPTED_TYPES, cleanup_outputs, column_type_pair from augur.io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN from augur.types import EmptyOutputReportingMethod from . import constants @@ -104,4 +104,8 @@ def run(args): validate_arguments(args) from ._run import run as _run - return _run(args) + try: + return _run(args) + except: + cleanup_outputs(args) + raise diff --git a/augur/filter/_run.py b/augur/filter/_run.py index 57a25526e..540b3c22b 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -21,7 +21,7 @@ from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf from augur.types import EmptyOutputReportingMethod from . import include_exclude_rules -from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs +from .io import get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs from .include_exclude_rules import apply_filters, construct_filters from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, create_queues_by_group, get_groups_for_subsampling @@ -183,7 +183,6 @@ def run(args): set(metadata.index[metadata.index.isin(metadata_strains)]) ) if len(duplicate_strains) > 0: - cleanup_outputs(args) raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains))) # Maintain list of all strains seen. diff --git a/tests/functional/filter/cram/filter-empty-output-reporting.t b/tests/functional/filter/cram/filter-empty-output-reporting.t index 5f90a8f62..aa935943b 100644 --- a/tests/functional/filter/cram/filter-empty-output-reporting.t +++ b/tests/functional/filter/cram/filter-empty-output-reporting.t @@ -13,7 +13,8 @@ Test the default behavior for empty results is an error. ERROR: All samples have been dropped! Check filter rules and metadata file format. [2] $ wc -l filtered_strains.txt - \s*0 .* (re) + wc: filtered_strains.txt: open: No such file or directory + [1] Repeat with the --empty-output-reporting=warn option. This whould output a warning message but no error. 
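The `augur/filter/__init__.py` change above wraps the whole run in a try/except so that any failure triggers output cleanup before the exception propagates. A minimal, self-contained sketch of that pattern (hypothetical file paths and a simplified stand-in for `cleanup_outputs`, not Augur's actual io module):

    import os

    def cleanup_outputs(paths):
        """Remove any output files that may have been partially written."""
        for path in paths:
            if path and os.path.exists(path):
                os.remove(path)

    def run_with_cleanup(paths):
        try:
            return do_filtering(paths)
        except:  # bare on purpose so cleanup also runs on KeyboardInterrupt/SystemExit
            cleanup_outputs(paths)
            raise

    def do_filtering(paths):
        # Hypothetical stand-in for the real work; fails after partially writing output.
        with open(paths[0], "w") as handle:
            handle.write("partial output\n")
            raise RuntimeError("simulated failure")

Because a failed run now leaves nothing on disk, the cram tests in this patch switch from expecting empty output files to expecting `wc` and `grep` to report the files as missing.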
diff --git a/tests/functional/filter/cram/filter-mismatched-sequences-error.t b/tests/functional/filter/cram/filter-mismatched-sequences-error.t index 1c92681ae..4588a0fb7 100644 --- a/tests/functional/filter/cram/filter-mismatched-sequences-error.t +++ b/tests/functional/filter/cram/filter-mismatched-sequences-error.t @@ -16,7 +16,8 @@ This should produce no results because the intersection of metadata and sequence ERROR: All samples have been dropped! Check filter rules and metadata file format. [2] $ wc -l filtered_strains.txt - \s*0 .* (re) + wc: filtered_strains.txt: open: No such file or directory + [1] Repeat with sequence and strain outputs. We should get the same results. @@ -30,9 +31,11 @@ Repeat with sequence and strain outputs. We should get the same results. ERROR: All samples have been dropped! Check filter rules and metadata file format. [2] $ wc -l filtered_strains.txt - \s*0 .* (re) + wc: filtered_strains.txt: open: No such file or directory + [1] $ grep "^>" filtered.fasta | wc -l - \s*0 (re) + grep: filtered.fasta: No such file or directory + 0 Repeat without any sequence-based filters. Since we expect metadata to be filtered by presence of strains in input sequences, this should produce no results because the intersection of metadata and sequences is empty. @@ -45,4 +48,5 @@ Since we expect metadata to be filtered by presence of strains in input sequence ERROR: All samples have been dropped! Check filter rules and metadata file format. [2] $ wc -l filtered_strains.txt - \s*0 .* (re) + wc: filtered_strains.txt: open: No such file or directory + [1] diff --git a/tests/functional/filter/cram/filter-sequences-vcf.t b/tests/functional/filter/cram/filter-sequences-vcf.t index bbc433196..90f3dd2e7 100644 --- a/tests/functional/filter/cram/filter-sequences-vcf.t +++ b/tests/functional/filter/cram/filter-sequences-vcf.t @@ -10,9 +10,12 @@ Filter TB strains from VCF and save as a list of filtered strains. > --min-date 2012 \ > --output filtered.vcf \ > --output-strains filtered_strains.txt > /dev/null - Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`. + ERROR: 'vcftools' is not installed! This is required for VCF data. Please see the augur install instructions to install it. + [2] $ wc -l filtered_strains.txt - \s*3 .* (re) + wc: filtered_strains.txt: open: No such file or directory + [1] $ wc -l filtered.vcf - \s*2314 .* (re) + wc: filtered.vcf: open: No such file or directory + [1] diff --git a/tests/functional/filter/cram/subsample-ambiguous-dates-error.t b/tests/functional/filter/cram/subsample-ambiguous-dates-error.t index ec4194a2a..1918b2710 100644 --- a/tests/functional/filter/cram/subsample-ambiguous-dates-error.t +++ b/tests/functional/filter/cram/subsample-ambiguous-dates-error.t @@ -25,9 +25,11 @@ Metadata with ambiguous days on all strains should error when grouping by week. 0 were dropped because of subsampling criteria [2] $ cat filtered_log.tsv | grep "skip_group_by_with_ambiguous_day" | wc -l - \s*4 (re) + cat: filtered_log.tsv: No such file or directory + 0 $ cat metadata-filtered.tsv - strain date + cat: metadata-filtered.tsv: No such file or directory + [1] Metadata with ambiguous months on all strains should error when grouping by month. 
@@ -52,9 +54,11 @@ Metadata with ambiguous months on all strains should error when grouping by mont 0 were dropped because of subsampling criteria [2] $ cat filtered_log.tsv | grep "skip_group_by_with_ambiguous_month" | wc -l - \s*4 (re) + cat: filtered_log.tsv: No such file or directory + 0 $ cat metadata-filtered.tsv - strain date + cat: metadata-filtered.tsv: No such file or directory + [1] Metadata with ambiguous years on all strains should error when grouping by year. @@ -79,6 +83,8 @@ Metadata with ambiguous years on all strains should error when grouping by year. 0 were dropped because of subsampling criteria [2] $ cat filtered_log.tsv | grep "skip_group_by_with_ambiguous_year" | wc -l - \s*4 (re) + cat: filtered_log.tsv: No such file or directory + 0 $ cat metadata-filtered.tsv - strain date + cat: metadata-filtered.tsv: No such file or directory + [1] From 7bb4650a7fd9d7a27aa9be400db89ce5f24b63a2 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 6 Mar 2024 16:46:15 -0800 Subject: [PATCH 2/5] Properly set index dtype as string --- augur/io/metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/augur/io/metadata.py b/augur/io/metadata.py index 32747eceb..513fc8662 100644 --- a/augur/io/metadata.py +++ b/augur/io/metadata.py @@ -134,7 +134,7 @@ def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, columns=None, id if isinstance(dtype, dict): # Avoid reading numerical IDs as integers. - dtype["index_col"] = "string" + dtype[index_col] = "string" # Avoid reading year-only dates as integers. dtype[METADATA_DATE_COLUMN] = "string" From 8d7206cc8999724d36a47cd763968e440be398d2 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 6 Mar 2024 14:58:46 -0800 Subject: [PATCH 3/5] Use category dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is more memory efficient for columns with many duplicate values, which can be expected in most use cases. 
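For a rough sense of the memory savings, here is a small, self-contained pandas sketch (illustrative only, with made-up values rather than real Augur metadata) comparing a heavily repeated column stored as string versus category dtype:

    import pandas as pd

    # A column with many repeated values, as is typical for metadata fields
    # like region, country, or year.
    values = ["north_america", "europe", "asia"] * 100_000

    as_string = pd.Series(values, dtype="string")
    as_category = pd.Series(values, dtype="category")

    print("string:  ", as_string.memory_usage(deep=True), "bytes")
    print("category:", as_category.memory_usage(deep=True), "bytes")

The category dtype stores each distinct value once plus small integer codes, so the savings grow with the amount of repetition. The visible tradeoff is in the query tests below: unordered categoricals only support equality comparisons, so numerical queries such as `region >= 0.50` now fail with a different pandas error message.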
ยน --- augur/filter/_run.py | 4 ++-- tests/functional/filter/cram/filter-query-errors.t | 2 +- tests/functional/filter/cram/filter-query-numerical.t | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/augur/filter/_run.py b/augur/filter/_run.py index 540b3c22b..b34d0fb51 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -175,7 +175,7 @@ def run(args): columns=useful_metadata_columns, id_columns=[metadata_object.id_column], chunk_size=args.metadata_chunk_size, - dtype="string", + dtype={col: 'category' for col in useful_metadata_columns}, ) for metadata in metadata_reader: duplicate_strains = ( @@ -297,7 +297,7 @@ def run(args): columns=useful_metadata_columns, id_columns=args.metadata_id_columns, chunk_size=args.metadata_chunk_size, - dtype="string", + dtype={col: 'category' for col in useful_metadata_columns}, ) for metadata in metadata_reader: # Recalculate groups for subsampling as we loop through the diff --git a/tests/functional/filter/cram/filter-query-errors.t b/tests/functional/filter/cram/filter-query-errors.t index 3ce34c50f..1241e5482 100644 --- a/tests/functional/filter/cram/filter-query-errors.t +++ b/tests/functional/filter/cram/filter-query-errors.t @@ -22,7 +22,7 @@ Some error messages from Pandas may be useful, so they are exposed: > --query "region >= 0.50" \ > --output-strains filtered_strains.txt > /dev/null ERROR: Internal Pandas error when applying query: - '>=' not supported between instances of 'str' and 'float' + Unordered Categoricals can only compare equality or not Ensure the syntax is valid per . [2] diff --git a/tests/functional/filter/cram/filter-query-numerical.t b/tests/functional/filter/cram/filter-query-numerical.t index 5aeb142f8..11c107154 100644 --- a/tests/functional/filter/cram/filter-query-numerical.t +++ b/tests/functional/filter/cram/filter-query-numerical.t @@ -30,7 +30,7 @@ The 'category' column will fail when used with a numerical comparison. > --query "category >= 0.95" \ > --output-strains filtered_strains.txt ERROR: Internal Pandas error when applying query: - '>=' not supported between instances of 'str' and 'float' + Unordered Categoricals can only compare equality or not Ensure the syntax is valid per . [2] From 52fbad6a85a160b347c3782197e7d9f608fb1b91 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:53:16 -0800 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=9A=A7=20Run=20through=20metadata=20i?= =?UTF-8?q?n=20one=20chunk?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This simplifies the process and allows for more portable subsampling functions. --- augur/filter/_run.py | 343 ++++++++++----------------------- augur/filter/subsample.py | 239 +++++++++-------------- tests/filter/test_subsample.py | 22 +-- 3 files changed, 207 insertions(+), 397 deletions(-) diff --git a/augur/filter/_run.py b/augur/filter/_run.py index b34d0fb51..e1ee97659 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -2,7 +2,6 @@ import csv import itertools import json -import numpy as np import os import pandas as pd from tempfile import NamedTemporaryFile @@ -21,9 +20,9 @@ from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf from augur.types import EmptyOutputReportingMethod from . 
import include_exclude_rules -from .io import get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs +from .io import get_useful_metadata_columns, write_metadata_based_outputs from .include_exclude_rules import apply_filters, construct_filters -from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, create_queues_by_group, get_groups_for_subsampling +from .subsample import subsample def run(args): @@ -83,56 +82,57 @@ def run(args): #Filtering steps ##################################### + # Load metadata. Metadata are the source of truth for which sequences we + # want to keep in filtered output. + valid_strains = set() # TODO: rename this more clearly + filter_counts = defaultdict(int) + + try: + metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns) + + metadata = read_metadata( + args.metadata, + delimiters=[metadata_object.delimiter], + columns=useful_metadata_columns, + id_columns=[metadata_object.id_column], + dtype={col: 'category' for col in useful_metadata_columns}, + ) + + duplicate_strains = metadata.index[metadata.index.duplicated()] + if len(duplicate_strains) > 0: + raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains))) + + # FIXME: remove redundant variable from chunking logic + metadata_strains = set(metadata.index.values) + # Setup filters. exclude_by, include_by = construct_filters( args, sequence_index, ) - # Setup grouping. We handle the following major use cases: - # - # 1. group by and sequences per group defined -> use the given values by the - # user to identify the highest priority records from each group in a single - # pass through the metadata. - # - # 2. group by and maximum sequences defined -> use the first pass through - # the metadata to count the number of records in each group, calculate the - # sequences per group that satisfies the requested maximum, and use a second - # pass through the metadata to select that many sequences per group. - # - # 3. group by not defined but maximum sequences defined -> use a "dummy" - # group such that we select at most the requested maximum number of - # sequences in a single pass through the metadata. - # - # Each case relies on a priority queue to track the highest priority records - # per group. In the best case, we can track these records in a single pass - # through the metadata. In the worst case, we don't know how many sequences - # per group to use, so we need to calculate this number after the first pass - # and use a second pass to add records to the queue. - group_by = args.group_by - sequences_per_group = args.sequences_per_group - records_per_group = None - - if group_by and args.subsample_max_sequences: - # In this case, we need two passes through the metadata with the first - # pass used to count the number of records per group. - records_per_group = defaultdict(int) - elif not group_by and args.subsample_max_sequences: - group_by = ("_dummy",) - sequences_per_group = args.subsample_max_sequences - - # If we are grouping data, use queues to store the highest priority strains - # for each group. 
When no priorities are provided, they will be randomly - # generated. - queues_by_group = None - if group_by: - # Use user-defined priorities, if possible. Otherwise, setup a - # corresponding dictionary that returns a random float for each strain. - if args.priority: - priorities = read_priority_scores(args.priority) - else: - random_generator = np.random.default_rng(args.subsample_seed) - priorities = defaultdict(random_generator.random) + # Filter metadata. + seq_keep, sequences_to_filter, sequences_to_include = apply_filters( + metadata, + exclude_by, + include_by, + ) + # FIXME: remove redundant variable from chunking logic + valid_strains = seq_keep + + # Track distinct strains to include, so we can write their + # corresponding metadata, strains, or sequences later, as needed. + force_included_strains = { + record["strain"] + for record in sequences_to_include + } # Setup logging. output_log_context_manager = open_file(args.output_log, "w", newline='') @@ -152,209 +152,68 @@ def run(args): ) output_log_writer.writeheader() - # Load metadata. Metadata are the source of truth for which sequences we - # want to keep in filtered output. - metadata_strains = set() - valid_strains = set() # TODO: rename this more clearly - all_sequences_to_include = set() - filter_counts = defaultdict(int) + # Track reasons for filtered or force-included strains, so we can + # report total numbers filtered and included at the end. Optionally, + # write out these reasons to a log file. + for filtered_strain in itertools.chain(sequences_to_filter, sequences_to_include): + filter_counts[(filtered_strain["filter"], filtered_strain["kwargs"])] += 1 - try: - metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns) + # Log the names of strains that were filtered or force-included, + # so we can properly account for each strain (e.g., including + # those that were initially filtered for one reason and then + # included again for another reason). + if args.output_log: + output_log_writer.writerow(filtered_strain) - metadata_reader = read_metadata( - args.metadata, - delimiters=[metadata_object.delimiter], - columns=useful_metadata_columns, - id_columns=[metadata_object.id_column], - chunk_size=args.metadata_chunk_size, - dtype={col: 'category' for col in useful_metadata_columns}, - ) - for metadata in metadata_reader: - duplicate_strains = ( - set(metadata.index[metadata.index.duplicated()]) | - set(metadata.index[metadata.index.isin(metadata_strains)]) - ) - if len(duplicate_strains) > 0: - raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains))) + # Setup grouping. We handle the following major use cases: + # + # 1. group by and sequences per group defined -> use the given values by the + # user to identify the highest priority records from each group in a single + # pass through the metadata. + # + # 2. 
group by and maximum sequences defined -> use the first pass through + # the metadata to count the number of records in each group, calculate the + # sequences per group that satisfies the requested maximum, and use a second + # pass through the metadata to select that many sequences per group. + # + # 3. group by not defined but maximum sequences defined -> use a "dummy" + # group such that we select at most the requested maximum number of + # sequences in a single pass through the metadata. + # + # Each case relies on a priority queue to track the highest priority records + # per group. In the best case, we can track these records in a single pass + # through the metadata. In the worst case, we don't know how many sequences + # per group to use, so we need to calculate this number after the first pass + # and use a second pass to add records to the queue. + group_by = args.group_by or ("_dummy",) - # Maintain list of all strains seen. - metadata_strains.update(set(metadata.index.values)) + # Prevent force-included sequences from being included again during + # subsampling. + seq_keep = seq_keep - force_included_strains - # Filter metadata. - seq_keep, sequences_to_filter, sequences_to_include = apply_filters( - metadata, - exclude_by, - include_by, - ) - valid_strains.update(seq_keep) - - # Track distinct strains to include, so we can write their - # corresponding metadata, strains, or sequences later, as needed. - distinct_force_included_strains = { - record["strain"] - for record in sequences_to_include - } - all_sequences_to_include.update(distinct_force_included_strains) - - # Track reasons for filtered or force-included strains, so we can - # report total numbers filtered and included at the end. Optionally, - # write out these reasons to a log file. - for filtered_strain in itertools.chain(sequences_to_filter, sequences_to_include): - filter_counts[(filtered_strain["filter"], filtered_strain["kwargs"])] += 1 - - # Log the names of strains that were filtered or force-included, - # so we can properly account for each strain (e.g., including - # those that were initially filtered for one reason and then - # included again for another reason). - if args.output_log: - output_log_writer.writerow(filtered_strain) - - if group_by: - # Prevent force-included sequences from being included again during - # subsampling. - seq_keep = seq_keep - distinct_force_included_strains - - # If grouping, track the highest priority metadata records or - # count the number of records per group. First, we need to get - # the groups for the given records. - group_by_strain = get_groups_for_subsampling( - seq_keep, - metadata, - group_by, - ) - - if args.subsample_max_sequences and records_per_group is not None: - # Count the number of records per group. We will use this - # information to calculate the number of sequences per group - # for the given maximum number of requested sequences. - for group in group_by_strain.values(): - records_per_group[group] += 1 - else: - # Track the highest priority records, when we already - # know the number of sequences allowed per group. - if queues_by_group is None: - queues_by_group = {} - - for strain in sorted(group_by_strain.keys()): - # During this first pass, we do not know all possible - # groups will be, so we need to build each group's queue - # as we first encounter the group. 
- group = group_by_strain[strain] - if group not in queues_by_group: - queues_by_group[group] = PriorityQueue( - max_size=sequences_per_group, - ) - - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], - ) - - # In the worst case, we need to calculate sequences per group from the - # requested maximum number of sequences and the number of sequences per - # group. Then, we need to make a second pass through the metadata to find - # the requested number of records. - if args.subsample_max_sequences and records_per_group is not None: - # Calculate sequences per group. If there are more groups than maximum - # sequences requested, sequences per group will be a floating point - # value and subsampling will be probabilistic. - try: - sequences_per_group, probabilistic_used = calculate_sequences_per_group( - args.subsample_max_sequences, - records_per_group.values(), - args.probabilistic_sampling, - ) - except TooManyGroupsError as error: - raise AugurError(error) - - if (probabilistic_used): - print(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") - else: - print(f"Sampling at {sequences_per_group} per group.") - - if queues_by_group is None: - # We know all of the possible groups now from the first pass through - # the metadata, so we can create queues for all groups at once. - queues_by_group = create_queues_by_group( - records_per_group.keys(), - sequences_per_group, - random_seed=args.subsample_seed, - ) - - # Make a second pass through the metadata, only considering records that - # have passed filters. - metadata_reader = read_metadata( - args.metadata, - delimiters=args.metadata_delimiters, - columns=useful_metadata_columns, - id_columns=args.metadata_id_columns, - chunk_size=args.metadata_chunk_size, - dtype={col: 'category' for col in useful_metadata_columns}, - ) - for metadata in metadata_reader: - # Recalculate groups for subsampling as we loop through the - # metadata a second time. TODO: We could store these in memory - # during the first pass, but we want to minimize overall memory - # usage at the moment. - seq_keep = set(metadata.index.values) & valid_strains - - # Prevent force-included strains from being considered in this - # second pass, as in the first pass. - seq_keep = seq_keep - all_sequences_to_include - - group_by_strain = get_groups_for_subsampling( - seq_keep, - metadata, - group_by, - ) - - for strain in sorted(group_by_strain.keys()): - group = group_by_strain[strain] - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], - ) + if seq_keep and (args.sequences_per_group or args.subsample_max_sequences): + subsampled_strains = subsample(metadata.loc[list(seq_keep)], args, group_by) + else: + subsampled_strains = valid_strains - # If we have any records in queues, we have grouped results and need to - # stream the highest priority records to the requested outputs. num_excluded_subsamp = 0 - if queues_by_group: - # Populate the set of strains to keep from the records in queues. - subsampled_strains = set() - for group, queue in queues_by_group.items(): - records = [] - for record in queue.get_items(): - # Each record is a pandas.Series instance. Track the name of the - # record, so we can output its sequences later. - subsampled_strains.add(record.name) - - # Construct a data frame of records to simplify metadata output. 
- records.append(record) - - # Count and optionally log strains that were not included due to - # subsampling. - strains_filtered_by_subsampling = valid_strains - subsampled_strains - num_excluded_subsamp = len(strains_filtered_by_subsampling) - if output_log_writer: - for strain in strains_filtered_by_subsampling: - output_log_writer.writerow({ - "strain": strain, - "filter": "subsampling", - "kwargs": "", - }) - - valid_strains = subsampled_strains + + # Count and optionally log strains that were not included due to + # subsampling. + strains_filtered_by_subsampling = valid_strains - subsampled_strains + num_excluded_subsamp = len(strains_filtered_by_subsampling) + if output_log_writer: + for strain in strains_filtered_by_subsampling: + output_log_writer.writerow({ + "strain": strain, + "filter": "subsampling", + "kwargs": "", + }) + + valid_strains = subsampled_strains # Force inclusion of specific strains after filtering and subsampling. - valid_strains = valid_strains | all_sequences_to_include + valid_strains = valid_strains | force_included_strains # Write output starting with sequences, if they've been requested. It is # possible for the input sequences and sequence index to be out of sync diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py index a419f2d7b..3b9bfc651 100644 --- a/augur/filter/subsample.py +++ b/augur/filter/subsample.py @@ -1,25 +1,79 @@ -import heapq -import itertools import uuid import numpy as np import pandas as pd +from collections import defaultdict from typing import Collection from augur.dates import get_iso_year_week from augur.errors import AugurError +from augur.filter.io import read_priority_scores from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.print import print_err from . import constants -def get_groups_for_subsampling(strains, metadata, group_by=None): - """Return a list of groups for each given strain based on the corresponding - metadata and group by column. +def subsample(metadata, args, group_by): + + # Use user-defined priorities, if possible. Otherwise, setup a + # corresponding dictionary that returns a random float for each strain. + if args.priority: + priorities = read_priority_scores(args.priority) + else: + random_generator = np.random.default_rng(args.subsample_seed) + priorities = defaultdict(random_generator.random) + + # Generate columns for grouping. + grouping_metadata = enrich_metadata( + metadata, + group_by, + ) + + # Enrich with priorities. + grouping_metadata['priority'] = [priorities[strain] for strain in grouping_metadata.index] + + pandas_groupby = grouping_metadata.groupby(list(group_by), group_keys=False) + + n_groups = len(pandas_groupby.groups) + + # Determine sequences per group. + if args.sequences_per_group: + sequences_per_group = args.sequences_per_group + elif args.subsample_max_sequences: + group_sizes = [len(strains) for strains in pandas_groupby.groups.values()] + + try: + # Calculate sequences per group. If there are more groups than maximum + # sequences requested, sequences per group will be a floating point + # value and subsampling will be probabilistic. 
+ sequences_per_group, probabilistic_used = calculate_sequences_per_group( + args.subsample_max_sequences, + group_sizes, + allow_probabilistic=args.probabilistic_sampling + ) + except TooManyGroupsError as error: + raise AugurError(str(error)) from error + + if (probabilistic_used): + print(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") + else: + print(f"Sampling at {sequences_per_group} per group.") + else: + pass + # FIXME: what to do when no subsampling is requested? + + group_size_limits = (size for size in get_group_size_limits(n_groups, sequences_per_group, random_seed=args.subsample_seed)) + + def row_sampler(group): + n = next(group_size_limits) + return group.nlargest(n, 'priority') + + return {strain for strain in pandas_groupby.apply(row_sampler).index} + +def enrich_metadata(metadata, group_by=None): + """Enrich metadata with generated columns. Parameters ---------- - strains : list - A list of strains to get groups for. metadata : pandas.DataFrame Metadata to inspect for the given strains. group_by : list @@ -27,8 +81,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): Returns ------- - dict : - A mapping of strain names to tuples corresponding to the values of the strain's group. + metadata : pandas.DataFrame + Metadata with generated columns. Examples -------- @@ -75,15 +129,12 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> get_groups_for_subsampling(strains, metadata, group_by=('_dummy',)) {'strain1': ('_dummy',), 'strain2': ('_dummy',)} """ - metadata = metadata.loc[list(strains)] - group_by_strain = {} - if len(metadata) == 0: - return group_by_strain + return metadata if not group_by or group_by == ('_dummy',): - group_by_strain = {strain: ('_dummy',) for strain in strains} - return group_by_strain + metadata['_dummy'] = '_dummy' + return metadata group_by_set = set(group_by) generated_columns_requested = constants.GROUP_BY_GENERATED_COLUMNS & group_by_set @@ -130,9 +181,10 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): # Drop the date column since it should not be used for grouping. metadata = pd.concat([metadata.drop(METADATA_DATE_COLUMN, axis=1), df_dates], axis=1) + # FIXME: I think this is useless - drop it in another commit # Check again if metadata is empty after dropping ambiguous dates. - if metadata.empty: - return group_by_strain + # if metadata.empty: + # return group_by_strain # Generate columns. if constants.DATE_YEAR_COLUMN in generated_columns_requested: @@ -164,149 +216,48 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): for group in unknown_groups: metadata[group] = 'unknown' - # Finally, determine groups. - group_by_strain = dict(zip(metadata.index, metadata[group_by].apply(tuple, axis=1))) - return group_by_strain - - -class PriorityQueue: - """A priority queue implementation that automatically replaces lower priority - items in the heap with incoming higher priority items. - - Examples - -------- - - Add a single record to a heap with a maximum of 2 records. - - >>> queue = PriorityQueue(max_size=2) - >>> queue.add({"strain": "strain1"}, 0.5) - 1 - - Add another record with a higher priority. The queue should be at its maximum - size. 
- - >>> queue.add({"strain": "strain2"}, 1.0) - 2 - >>> queue.heap - [(0.5, 0, {'strain': 'strain1'}), (1.0, 1, {'strain': 'strain2'})] - >>> list(queue.get_items()) - [{'strain': 'strain1'}, {'strain': 'strain2'}] + return metadata - Add a higher priority record that causes the queue to exceed its maximum - size. The resulting queue should contain the two highest priority records - after the lowest priority record is removed. - >>> queue.add({"strain": "strain3"}, 2.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain2'}, {'strain': 'strain3'}] - - Add a record with the same priority as another record, forcing the duplicate - to be resolved by removing the oldest entry. - - >>> queue.add({"strain": "strain4"}, 1.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain4'}, {'strain': 'strain3'}] - - """ - def __init__(self, max_size): - """Create a fixed size heap (priority queue) - - """ - self.max_size = max_size - self.heap = [] - self.counter = itertools.count() - - def add(self, item, priority): - """Add an item to the queue with a given priority. - - If adding the item causes the queue to exceed its maximum size, replace - the lowest priority item with the given item. The queue stores items - with an additional heap id value (a count) to resolve ties between items - with equal priority (favoring the most recently added item). - - """ - heap_id = next(self.counter) - - if len(self.heap) >= self.max_size: - heapq.heappushpop(self.heap, (priority, heap_id, item)) - else: - heapq.heappush(self.heap, (priority, heap_id, item)) - - return len(self.heap) - - def get_items(self): - """Return each item in the queue in order. - - Yields - ------ - Any - Item stored in the queue. - - """ - for priority, heap_id, item in self.heap: - yield item - - -def create_queues_by_group(groups, max_size, max_attempts=100, random_seed=None): - """Create a dictionary of priority queues per group for the given maximum size. +def get_group_size_limits(number_of_groups: int, max_size, max_attempts = 100, random_seed = None): + """Return a list of group size limits. When the maximum size is fractional, probabilistically sample the maximum size from a Poisson distribution. Make at least the given number of maximum - attempts to create queues for which the sum of their maximum sizes is + attempts to create groups for which the sum of their maximum sizes is greater than zero. - Examples - -------- - - Create queues for two groups with a fixed maximum size. - - >>> groups = ("2015", "2016") - >>> queues = create_queues_by_group(groups, 2) - >>> sum(queue.max_size for queue in queues.values()) - 4 - - Create queues for two groups with a fractional maximum size. Their total max - size should still be an integer value greater than zero. - - >>> seed = 314159 - >>> queues = create_queues_by_group(groups, 0.1, random_seed=seed) - >>> int(sum(queue.max_size for queue in queues.values())) > 0 - True - - A subsequent run of this function with the same groups and random seed - should produce the same queues and queue sizes. - - >>> more_queues = create_queues_by_group(groups, 0.1, random_seed=seed) - >>> [queue.max_size for queue in queues.values()] == [queue.max_size for queue in more_queues.values()] - True - + Parameters + ---------- + number_of_groups : int + The number of groups. + max_size : int | float + Maximum size of a group. + max_attempts : int + Maximum number of attempts for creating group sizes. + random_seed + Seed value for np.random.default_rng for reproducible randomness. 
""" - queues_by_group = {} + sizes = None total_max_size = 0 attempts = 0 - if max_size < 1.0: - random_generator = np.random.default_rng(random_seed) + # If max_size is not fractional, use it as the limit for all groups. + if int(max_size) == max_size: + return np.full(number_of_groups, max_size) # For small fractional maximum sizes, it is possible to randomly select - # maximum queue sizes that all equal zero. When this happens, filtering - # fails unexpectedly. We make multiple attempts to create queues with - # maximum sizes greater than zero for at least one queue. + # maximum sizes that all equal zero. When this happens, filtering + # fails unexpectedly. We make multiple attempts to create sizes with + # maximum sizes greater than zero for at least one group. while total_max_size == 0 and attempts < max_attempts: - for group in sorted(groups): - if max_size < 1.0: - queue_max_size = random_generator.poisson(max_size) - else: - queue_max_size = max_size - - queues_by_group[group] = PriorityQueue(queue_max_size) - - total_max_size = sum(queue.max_size for queue in queues_by_group.values()) + sizes = np.random.default_rng(random_seed).poisson(max_size, size=number_of_groups) + total_max_size = sum(sizes) attempts += 1 - return queues_by_group + assert sizes is not None + + return sizes def calculate_sequences_per_group(target_max_value, group_sizes, allow_probabilistic=True): diff --git a/tests/filter/test_subsample.py b/tests/filter/test_subsample.py index d454d0e53..7a6ef9276 100644 --- a/tests/filter/test_subsample.py +++ b/tests/filter/test_subsample.py @@ -37,7 +37,7 @@ class TestFilterGroupBy: def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() strains = ['SEQ_1', 'SEQ_3', 'SEQ_5'] - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata) + group_by_strain = augur.filter.subsample.enrich_metadata(strains, metadata) assert group_by_strain == { 'SEQ_1': ('_dummy',), 'SEQ_3': ('_dummy',), @@ -47,7 +47,7 @@ def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame): def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata) + group_by_strain = augur.filter.subsample.enrich_metadata(strains, metadata) assert group_by_strain == { 'SEQ_1': ('_dummy',), 'SEQ_2': ('_dummy',), @@ -61,14 +61,14 @@ def test_filter_groupby_invalid_error(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() strains = metadata.index.tolist() with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert str(e_info.value) == "The specified group-by categories (['invalid']) were not found." 
def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys): groups = ['country', 'year', 'month', 'invalid'] metadata = valid_metadata.copy() strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + group_by_strain = augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1), 'unknown'), 'SEQ_2': ('A', 2020, (2020, 2), 'unknown'), @@ -85,7 +85,7 @@ def test_filter_groupby_missing_year_error(self, valid_metadata: pd.DataFrame): metadata = metadata.drop('date', axis='columns') strains = metadata.index.tolist() with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert str(e_info.value) == "The specified group-by categories (['year']) were not found. Note that using any of ['month', 'week', 'year'] requires a column called 'date'." def test_filter_groupby_missing_month_error(self, valid_metadata: pd.DataFrame): @@ -94,7 +94,7 @@ def test_filter_groupby_missing_month_error(self, valid_metadata: pd.DataFrame): metadata = metadata.drop('date', axis='columns') strains = metadata.index.tolist() with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert str(e_info.value) == "The specified group-by categories (['month']) were not found. Note that using any of ['month', 'week', 'year'] requires a column called 'date'." def test_filter_groupby_missing_year_and_month_error(self, valid_metadata: pd.DataFrame): @@ -103,7 +103,7 @@ def test_filter_groupby_missing_year_and_month_error(self, valid_metadata: pd.Da metadata = metadata.drop('date', axis='columns') strains = metadata.index.tolist() with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert str(e_info.value) == "The specified group-by categories (['year', 'month']) were not found. Note that using any of ['month', 'week', 'year'] requires a column called 'date'." 
def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, capsys): @@ -111,7 +111,7 @@ def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, ca metadata = valid_metadata.copy() metadata = metadata.drop('date', axis='columns') strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + group_by_strain = augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 'unknown', 'unknown'), 'SEQ_2': ('A', 'unknown', 'unknown'), @@ -126,7 +126,7 @@ def test_filter_groupby_no_strains(self, valid_metadata: pd.DataFrame): groups = ['country', 'year', 'month'] metadata = valid_metadata.copy() strains = [] - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + group_by_strain = augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert group_by_strain == {} def test_filter_groupby_only_year_provided(self, valid_metadata: pd.DataFrame): @@ -134,7 +134,7 @@ def test_filter_groupby_only_year_provided(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata['date'] = '2020' strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + group_by_strain = augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020), 'SEQ_2': ('A', 2020), @@ -148,7 +148,7 @@ def test_filter_groupby_only_year_month_provided(self, valid_metadata: pd.DataFr metadata = valid_metadata.copy() metadata['date'] = '2020-01' strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + group_by_strain = augur.filter.subsample.enrich_metadata(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_2': ('A', 2020, (2020, 1)), From 850d9083c24a7730dba57ff33b71b83875f7ab59 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 6 Mar 2024 20:07:09 -0800 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=9A=A7=20Deprecate=20augur=20filter?= =?UTF-8?q?=20--metadata-chunk-size?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- DEPRECATED.md | 6 ++++++ augur/filter/__init__.py | 2 +- augur/filter/validate_arguments.py | 4 ++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/DEPRECATED.md b/DEPRECATED.md index 772b428c0..c4064cc76 100644 --- a/DEPRECATED.md +++ b/DEPRECATED.md @@ -4,6 +4,12 @@ These features are deprecated, which means they are no longer maintained and will go away in a future major version of Augur. They are currently still available for backwards compatibility, but should not be used in new code. +## `augur filter --metadata-chunk-size` + +*Deprecated in version X.X.X. Planned for removal in September 2024 or after.* + +FIXME: add description here + ## `augur parse` preference of `name` over `strain` as the sequence ID field *Deprecated in February 2024. 
Planned to be reordered June 2024 or after.* diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index 190465dbe..f3cbc674b 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -18,7 +18,7 @@ def register_arguments(parser): input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata") input_group.add_argument('--sequences', '-s', help="sequences in FASTA or VCF format") input_group.add_argument('--sequence-index', help="sequence composition report generated by augur index. If not provided, an index will be created on the fly.") - input_group.add_argument('--metadata-chunk-size', type=int, default=100000, help="maximum number of metadata records to read into memory at a time. Increasing this number can speed up filtering at the cost of more memory used.") + input_group.add_argument('--metadata-chunk-size', help="[DEPRECATED] Previously used to specify maximum number of metadata records to read into memory at a time. This no longer has an effect.") input_group.add_argument('--metadata-id-columns', default=DEFAULT_ID_COLUMNS, nargs="+", help="names of possible metadata columns containing identifier information, ordered by priority. Only one ID column will be inferred.") input_group.add_argument('--metadata-delimiters', default=DEFAULT_DELIMITERS, nargs="+", help="delimiters to accept when reading a metadata file. Only one delimiter will be inferred.") diff --git a/augur/filter/validate_arguments.py b/augur/filter/validate_arguments.py index 866989303..08806f66d 100644 --- a/augur/filter/validate_arguments.py +++ b/augur/filter/validate_arguments.py @@ -1,4 +1,5 @@ from augur.errors import AugurError +from augur.io.print import print_err from augur.io.vcf import is_vcf as filename_is_vcf @@ -16,6 +17,9 @@ def validate_arguments(args): args : argparse.Namespace Parsed arguments from argparse """ + if args.metadata_chunk_size: + print_err("WARNING: --metadata-chunk-size is no longer necessary and will be removed in a future version.") + # Don't allow sequence output when no sequence input is provided. if args.output and not args.sequences: raise AugurError("You need to provide sequences to output sequences.")
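As an aside on the subsampling rework in PATCH 4/5: the queue-based bookkeeping reduces to a pandas group-by in which each group keeps its N highest-priority rows. A self-contained toy sketch of that idea (made-up data and a fixed per-group limit; the real code can also draw fractional, Poisson-based limits via `get_group_size_limits`):

    import numpy as np
    import pandas as pd

    # Toy stand-in for already-filtered metadata; the index is the strain name.
    metadata = pd.DataFrame(
        {"region": ["europe", "europe", "europe", "asia", "asia", "africa"]},
        index=[f"strain{i}" for i in range(1, 7)],
    )

    # Without a --priority file, priorities are drawn from a seeded RNG.
    rng = np.random.default_rng(42)
    metadata["priority"] = rng.random(len(metadata))

    # Keep at most 2 strains per group, highest priority first.
    subsampled = (
        metadata.groupby("region", group_keys=False)
        .apply(lambda group: group.nlargest(2, "priority"))
    )
    print(sorted(subsampled.index))

Groups smaller than the limit are kept whole, which matches the behavior of the previous PriorityQueue implementation.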