diff --git a/Cargo.lock b/Cargo.lock
index 8cdefb3..73568a3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -861,6 +861,7 @@ dependencies = [
  "indexmap",
  "indicatif",
  "itertools",
+ "lazy_static",
  "noodles",
  "num-format",
  "plotly",
diff --git a/Cargo.toml b/Cargo.toml
index b06c9ee..7b900d1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,7 @@ git-testament = "0.2.1"
 indexmap = "1.9.1"
 indicatif = "0.16.2"
 itertools = "0.10.5"
+lazy_static = "1.4.0"
 noodles = { version = "0.34.0", features = [
     "async",
     "bam",
diff --git a/src/convert/bam.rs b/src/convert/bam.rs
index 59068dd..153c52e 100644
--- a/src/convert/bam.rs
+++ b/src/convert/bam.rs
@@ -45,7 +45,7 @@ pub async fn to_sam_async(
         .await
         .with_context(|| "writing SAM header")?;
 
-    let mut counter = RecordCounter::new();
+    let mut counter = RecordCounter::default();
     let mut record = Record::default();
 
     // (4) Write each record in the BAM file to the SAM file.
@@ -131,7 +131,7 @@ pub async fn to_cram_async(
         .await
         .with_context(|| "writing CRAM file header")?;
 
-    let mut counter = RecordCounter::new();
+    let mut counter = RecordCounter::default();
     let mut record = Record::default();
 
     // (6) Write each record in the BAM file to the CRAM file.
diff --git a/src/convert/command.rs b/src/convert/command.rs
index ea20904..76a7a43 100644
--- a/src/convert/command.rs
+++ b/src/convert/command.rs
@@ -31,8 +31,13 @@ pub struct ConvertArgs {
     to: PathBuf,
 
     /// Number of records to process before exiting the conversion.
-    #[arg(short = 'n', long, value_name = "USIZE")]
-    num_records: Option<usize>,
+    #[arg(
+        short,
+        long,
+        default_value_t,
+        value_name = "'all' or a positive, non-zero integer"
+    )]
+    num_records: NumberOfRecords,
 
     /// If available, the FASTA reference file used to generate the file.
     #[arg(short, long)]
@@ -91,7 +96,7 @@ pub fn convert(args: ConvertArgs) -> anyhow::Result<()> {
     // Number of Records //
     //===================//
 
-    let max_records = NumberOfRecords::from(args.num_records);
+    let max_records = args.num_records;
 
     //==========================//
     // Bioinformatics File Pair //
diff --git a/src/convert/cram.rs b/src/convert/cram.rs
index f9045e9..6679cc1 100644
--- a/src/convert/cram.rs
+++ b/src/convert/cram.rs
@@ -52,7 +52,7 @@ pub async fn to_sam_async(
         .with_context(|| "writing SAM header")?;
 
     // (5) Write each record in the CRAM file to the SAM file.
-    let mut counter = RecordCounter::new();
+    let mut counter = RecordCounter::default();
     let mut records = reader.records(&repository, &header.parsed);
 
     while let Some(record) = records
@@ -125,7 +125,7 @@ pub async fn to_bam_async(
         .with_context(|| "writing BAM reference sequences")?;
 
     // (6) Write each record in the CRAM file to the BAM file.
-    let mut counter = RecordCounter::new();
+    let mut counter = RecordCounter::default();
     let mut records = reader.records(&repository, &header.parsed);
 
     while let Some(record) = records
diff --git a/src/convert/sam.rs b/src/convert/sam.rs
index 697e52b..ea0bddb 100644
--- a/src/convert/sam.rs
+++ b/src/convert/sam.rs
@@ -58,7 +58,7 @@ pub async fn to_bam_async(
         .await
         .with_context(|| "writing BAM reference sequences")?;
 
-    let mut counter = RecordCounter::new();
+    let mut counter = RecordCounter::default();
     let mut record = Record::default();
 
     // (5) Write each record in the BAM file to the SAM file.
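Several hunks in this patch swap `Option<usize>` for the new `NumberOfRecords` type, relying on clap's `default_value_t`. The type itself lives in `src/utils/args.rs`, which is not part of this section, so the following is only a sketch of the contract these call sites assume: `default_value_t` requires `Default + Display`, and parsing the documented "'all' or a positive, non-zero integer" value requires `FromStr`.

```rust
// Hypothetical sketch of `crate::utils::args::NumberOfRecords`; the real
// definition is outside this diff.
use std::fmt;
use std::str::FromStr;

#[derive(Clone, Debug, Default)]
pub enum NumberOfRecords {
    /// Process every record in the file (the default).
    #[default]
    All,
    /// Process only the first `n` records.
    Some(usize),
}

impl fmt::Display for NumberOfRecords {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NumberOfRecords::All => write!(f, "all"),
            NumberOfRecords::Some(n) => write!(f, "{}", n),
        }
    }
}

impl FromStr for NumberOfRecords {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.eq_ignore_ascii_case("all") {
            return Ok(NumberOfRecords::All);
        }
        match s.parse::<usize>() {
            Ok(n) if n > 0 => Ok(NumberOfRecords::Some(n)),
            _ => Err(format!(
                "expected 'all' or a positive, non-zero integer, got '{}'",
                s
            )),
        }
    }
}
```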
@@ -151,7 +151,7 @@ pub async fn to_cram_async(
         .await
         .with_context(|| "writing CRAM file header")?;
 
-    let mut counter = RecordCounter::new();
+    let mut counter = RecordCounter::default();
     let mut record = Record::default();
 
     // (6) Write each record in the SAM file to the CRAM file.
diff --git a/src/derive.rs b/src/derive.rs
index 4c16afc..6a28e5f 100644
--- a/src/derive.rs
+++ b/src/derive.rs
@@ -1,4 +1,9 @@
 //! Functionality related to the `ngs derive` subcommand.
 
 pub mod command;
+pub mod encoding;
+pub mod endedness;
 pub mod instrument;
+pub mod junction_annotation;
+pub mod readlen;
+pub mod strandedness;
diff --git a/src/derive/command.rs b/src/derive/command.rs
index c372eb4..4b42d0d 100644
--- a/src/derive/command.rs
+++ b/src/derive/command.rs
@@ -1,6 +1,11 @@
 //! Functionality related to the `ngs derive` subcommand itself.
 
+pub mod encoding;
+pub mod endedness;
 pub mod instrument;
+pub mod junction_annotation;
+pub mod readlen;
+pub mod strandedness;
 
 use clap::Args;
 use clap::Subcommand;
@@ -20,6 +25,23 @@ pub struct DeriveArgs {
 /// All possible subcommands for `ngs derive`.
 #[derive(Subcommand)]
 pub enum DeriveSubcommand {
+    /// Derives the quality score encoding used to produce the file.
+    Encoding(self::encoding::DeriveEncodingArgs),
+
+    /// Derives the endedness of the file.
+    Endedness(self::endedness::DeriveEndednessArgs),
+
     /// Derives the instrument used to produce the file.
     Instrument(self::instrument::DeriveInstrumentArgs),
+
+    /// Derives the read length of the file.
+    Readlen(self::readlen::DeriveReadlenArgs),
+
+    /// Derives the strandedness of the RNA-Seq file.
+    Strandedness(self::strandedness::DeriveStrandednessArgs),
+
+    /// Annotates junctions in the file.
+    /// Note that, technically, this command doesn't derive anything; it is
+    /// included here for convenience and will be moved to a better home in
+    /// the future.
+    JunctionAnnotation(self::junction_annotation::JunctionAnnotationArgs),
 }
diff --git a/src/derive/command/encoding.rs b/src/derive/command/encoding.rs
new file mode 100644
index 0000000..ceaef6b
--- /dev/null
+++ b/src/derive/command/encoding.rs
@@ -0,0 +1,78 @@
+//! Functionality relating to the `ngs derive encoding` subcommand itself.
+
+use anyhow::{Context, Ok};
+use clap::Args;
+use noodles::bam;
+use num_format::{Locale, ToFormattedString};
+use std::collections::HashSet;
+use std::io::BufReader;
+use std::path::PathBuf;
+use tracing::info;
+
+use crate::derive::encoding::compute;
+use crate::utils::args::NumberOfRecords;
+use crate::utils::display::RecordCounter;
+
+/// Clap arguments for the `ngs derive encoding` subcommand.
+#[derive(Args)]
+pub struct DeriveEncodingArgs {
+    /// Source BAM.
+    #[arg(value_name = "BAM")]
+    src: PathBuf,
+
+    /// Examine the first `n` records in the file.
+    #[arg(
+        short,
+        long,
+        default_value_t,
+        value_name = "'all' or a positive, non-zero integer"
+    )]
+    num_records: NumberOfRecords,
+}
+
+/// Main function for the `ngs derive encoding` subcommand.
+pub fn derive(args: DeriveEncodingArgs) -> anyhow::Result<()> {
+    info!("Starting derive encoding subcommand.");
+
+    let file = std::fs::File::open(args.src);
+    let reader = file
+        .map(BufReader::new)
+        .with_context(|| "opening BAM file")?;
+    let mut reader = bam::Reader::new(reader);
+    let _header: String = reader.read_header()?.parse()?;
+    reader.read_reference_sequences()?;
+
+    let mut score_set: HashSet<u8> = HashSet::new();
+
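The sampling loops in this patch are all driven by `RecordCounter` (from `src/utils/display.rs`, also outside this section). Inferred from its call sites (`inc`, `get`, and `time_to_break`), a minimal sketch of the assumed interface, reusing the hypothetical `NumberOfRecords` sketch above:

```rust
// Hypothetical sketch of `crate::utils::display::RecordCounter`; the real
// type is defined outside this diff and may do more (e.g., progress logging).
#[derive(Default)]
pub struct RecordCounter {
    count: usize,
}

impl RecordCounter {
    /// Increments the counter by one record.
    pub fn inc(&mut self) {
        self.count += 1;
    }

    /// Returns the number of records counted so far.
    pub fn get(&self) -> usize {
        self.count
    }

    /// Returns `true` once the requested number of records has been seen.
    pub fn time_to_break(&self, limit: &NumberOfRecords) -> bool {
        match limit {
            NumberOfRecords::All => false,
            NumberOfRecords::Some(n) => self.count >= *n,
        }
    }
}
```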
+    // (1) Collect quality scores from reads within the
+    // file. Support for sampling only a portion of the reads is provided.
+    let mut counter = RecordCounter::default();
+    for result in reader.lazy_records() {
+        let record = result?;
+
+        for i in 0..record.quality_scores().len() {
+            let score = record.quality_scores().as_ref()[i];
+            score_set.insert(score);
+        }
+
+        counter.inc();
+        if counter.time_to_break(&args.num_records) {
+            break;
+        }
+    }
+
+    info!(
+        "Processed {} records.",
+        counter.get().to_formatted_string(&Locale::en)
+    );
+
+    // (2) Derive encoding from the observed quality scores
+    let result = compute::predict(score_set)?;
+
+    // (3) Print the output to stdout as JSON (more support for different output
+    // types may be added in the future, but for now, only JSON).
+    let output = serde_json::to_string_pretty(&result).unwrap();
+    println!("{}", output);
+
+    Ok(())
+}
diff --git a/src/derive/command/endedness.rs b/src/derive/command/endedness.rs
new file mode 100644
index 0000000..b7d38c6
--- /dev/null
+++ b/src/derive/command/endedness.rs
@@ -0,0 +1,164 @@
+//! Functionality relating to the `ngs derive endedness` subcommand itself.
+
+use anyhow::Context;
+use clap::Args;
+use num_format::{Locale, ToFormattedString};
+use std::collections::{HashMap, HashSet};
+use std::path::PathBuf;
+use std::sync::Arc;
+use tracing::{info, trace};
+
+use crate::derive::endedness::compute;
+use crate::derive::endedness::compute::OrderingFlagsCounts;
+use crate::utils::args::arg_in_range as deviance_in_range;
+use crate::utils::args::NumberOfRecords;
+use crate::utils::display::RecordCounter;
+use crate::utils::formats::bam::ParsedBAMFile;
+use crate::utils::formats::utils::IndexCheck;
+use crate::utils::read_groups::{get_read_group, validate_read_group_info, ReadGroupPtr};
+
+/// Clap arguments for the `ngs derive endedness` subcommand.
+#[derive(Args)]
+pub struct DeriveEndednessArgs {
+    /// Source BAM.
+    #[arg(value_name = "BAM")]
+    src: PathBuf,
+
+    /// Examine the first `n` records in the file.
+    #[arg(
+        short,
+        long,
+        default_value_t,
+        value_name = "'all' or a positive, non-zero integer"
+    )]
+    num_records: NumberOfRecords,
+
+    /// Distance from 0.5 split between number of f+l- reads and f-l+ reads
+    /// allowed to be called 'Paired-End'. The default value of `0.0` is only
+    /// appropriate if the whole file is being processed.
+    #[arg(long, value_name = "F64", default_value = "0.0")]
+    paired_deviance: f64,
+
+    /// Calculate and output Reads-Per-Template. This will produce a more
+    /// sophisticated estimate for endedness, but uses substantially more memory.
+    #[arg(long, default_value = "false")]
+    calculate_reads_per_template: bool,
+
+    /// Round RPT to the nearest INT before comparing to expected values.
+    /// Appropriate if using `-n` > 0. Unrounded value is reported in output.
+    #[arg(long, default_value = "false")]
+    round_reads_per_template: bool,
+}
+
+/// Main function for the `ngs derive endedness` subcommand.
+pub fn derive(args: DeriveEndednessArgs) -> anyhow::Result<()> {
+    // (0) Parse arguments needed for subcommand.
+    let paired_deviance = deviance_in_range(args.paired_deviance, 0.0..=0.5)
+        .with_context(|| "Paired deviance is not within acceptable range")?;
+
+    info!("Starting derive endedness subcommand.");
+
+    let mut found_rgs = HashSet::new();
+
+    let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+
+    // only used if args.calc_rpt is true
+    let mut read_names: Option<HashMap<String, Vec<ReadGroupPtr>>> = None;
+
+    let ParsedBAMFile {
+        mut reader, header, ..
+    } = crate::utils::formats::bam::open_and_parse(args.src, IndexCheck::None)?;
+
+    // (1) Collect ordering flags (and QNAMEs) from reads within the
Support for sampling only a portion of the reads is provided. + let mut counter = RecordCounter::default(); + for result in reader.records(&header.parsed) { + let record = result?; + + // Only count primary alignments and unmapped reads. + if (record.flags().is_secondary() || record.flags().is_supplementary()) + && !record.flags().is_unmapped() + { + continue; + } + + let read_group = get_read_group(&record, Some(&mut found_rgs)); + + if args.calculate_reads_per_template { + let read_name_map = read_names.get_or_insert_with(HashMap::new); + match record.read_name() { + Some(rn) => { + let rn = rn.to_string(); + let rg_vec = read_name_map.get_mut(&rn); + + match rg_vec { + Some(rg_vec) => { + rg_vec.push(Arc::clone(&read_group)); + } + None => { + read_name_map.insert(rn, vec![(Arc::clone(&read_group))]); + } + } + } + None => { + trace!("Could not parse a QNAME from a read in the file."); + trace!("Skipping this read and proceeding."); + continue; + } + } + } + + match ( + record.flags().is_segmented(), + record.flags().is_first_segment(), + record.flags().is_last_segment(), + ) { + (false, _, _) => { + ordering_flags.entry(read_group).or_default().unsegmented += 1; + } + (true, true, false) => { + ordering_flags.entry(read_group).or_default().first += 1; + } + (true, false, true) => { + ordering_flags.entry(read_group).or_default().last += 1; + } + (true, true, true) => { + ordering_flags.entry(read_group).or_default().both += 1; + } + (true, false, false) => { + ordering_flags.entry(read_group).or_default().neither += 1; + } + } + + counter.inc(); + if counter.time_to_break(&args.num_records) { + break; + } + } + + info!( + "Processed {} records.", + counter.get().to_formatted_string(&Locale::en) + ); + + // (2) Validate the read group information. + let rgs_in_header_not_records = validate_read_group_info(&found_rgs, &header.parsed); + for rg_id in rgs_in_header_not_records { + ordering_flags.insert(Arc::new(rg_id), OrderingFlagsCounts::new()); + } + + // (3) Derive the endedness based on the ordering flags gathered. + let result = compute::predict( + ordering_flags, + read_names, + paired_deviance, + args.round_reads_per_template, + ); + + // (4) Print the output to stdout as JSON (more support for different output + // types may be added in the future, but for now, only JSON). + let output = serde_json::to_string_pretty(&result).unwrap(); + println!("{}", output); + + anyhow::Ok(()) +} diff --git a/src/derive/command/instrument.rs b/src/derive/command/instrument.rs index c36ec19..4ef85ab 100644 --- a/src/derive/command/instrument.rs +++ b/src/derive/command/instrument.rs @@ -1,17 +1,19 @@ //! Functionality relating to the `ngs derive instrument` subcommand itself. -use anyhow::bail; -use std::collections::HashSet; -use std::path::PathBuf; -use std::thread; - use clap::Args; +use num_format::{Locale, ToFormattedString}; +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; +use std::sync::Arc; use tracing::info; use crate::derive::instrument::compute; use crate::derive::instrument::reads::IlluminaReadName; +use crate::utils::args::NumberOfRecords; +use crate::utils::display::RecordCounter; use crate::utils::formats::bam::ParsedBAMFile; use crate::utils::formats::utils::IndexCheck; +use crate::utils::read_groups::{get_read_group, validate_read_group_info, ReadGroupPtr}; /// Clap arguments for the `ngs derive instrument` subcommand. 
 /// Clap arguments for the `ngs derive instrument` subcommand.
 #[derive(Args)]
@@ -20,91 +22,91 @@ pub struct DeriveInstrumentArgs {
     #[arg(value_name = "BAM")]
     src: PathBuf,
 
-    /// Only examine the first n records in the file.
-    #[arg(short, long, value_name = "USIZE")]
-    num_records: Option<usize>,
-
-    /// Use a specific number of threads.
-    #[arg(short, long, value_name = "USIZE")]
-    threads: Option<usize>,
+    /// Examine the first `n` records in the file.
+    #[arg(
+        short,
+        long,
+        default_value = "10000000",
+        value_name = "'all' or a positive, non-zero integer"
+    )]
+    num_records: NumberOfRecords,
 }
 
-/// Entrypoint for the `ngs derive instrument` subcommand.
+/// Main function for the `ngs derive instrument` subcommand.
 pub fn derive(args: DeriveInstrumentArgs) -> anyhow::Result<()> {
-    let first_n_reads: Option<usize> = args.num_records;
-    let threads = match args.threads {
-        Some(t) => t,
-        None => thread::available_parallelism().map(usize::from)?,
-    };
-
-    info!(
-        "Starting derive instrument subcommand with {} threads.",
-        threads
-    );
-
-    let rt = tokio::runtime::Builder::new_multi_thread()
-        .worker_threads(threads)
-        .build()?;
-
-    rt.block_on(app(args.src, first_n_reads))
-}
+    let src = args.src;
+    let mut instrument_names: HashMap<ReadGroupPtr, HashSet<String>> = HashMap::new();
+    let mut flowcell_names: HashMap<ReadGroupPtr, HashSet<String>> = HashMap::new();
+    let mut metrics = compute::RecordMetrics::default();
+    let mut found_rgs = HashSet::new();
 
-/// Main function for the `ngs derive instrument` subcommand.
-async fn app(src: PathBuf, first_n_reads: Option<usize>) -> anyhow::Result<()> {
-    let mut instrument_names = HashSet::new();
-    let mut flowcell_names = HashSet::new();
+    info!("Starting derive instrument subcommand.");
 
     let ParsedBAMFile {
         mut reader, header, ..
-    } = crate::utils::formats::bam::open_and_parse(src, IndexCheck::Full)?;
+    } = crate::utils::formats::bam::open_and_parse(src, IndexCheck::None)?;
 
     // (1) Collect instrument names and flowcell names from reads within the
     // file. Support for sampling only a portion of the reads is provided.
-    let mut samples = 0;
-    let mut sample_max = 0;
-
-    if let Some(s) = first_n_reads {
-        sample_max = s;
-    }
-
+    let mut counter = RecordCounter::default();
     for result in reader.records(&header.parsed) {
         let record = result?;
+        let read_group = get_read_group(&record, Some(&mut found_rgs));
 
         if let Some(read_name) = record.read_name() {
             let name: &str = read_name.as_ref();
 
             match name.parse::<IlluminaReadName>() {
                 Ok(read) => {
-                    instrument_names.insert(read.instrument_name);
+                    instrument_names
+                        .entry(read_group.clone())
+                        .or_default()
+                        .insert(read.instrument_name);
+                    metrics.found_instrument_name += 1;
+                    // Get or init the flowcell set for this read group.
+                    let fc_entry = flowcell_names.entry(read_group).or_default();
                     if let Some(fc) = read.flowcell {
-                        flowcell_names.insert(fc);
+                        fc_entry.insert(fc);
+                        metrics.found_flowcell_name += 1;
                     }
                 }
                 Err(_) => {
-                    bail!(
-                        "Could not parse Illumina-formatted query names for read: {}",
-                        name
-                    );
+                    metrics.bad_read_name += 1;
                 }
             }
+        } else {
+            metrics.bad_read_name += 1;
         }
 
-        if sample_max > 0 {
-            samples += 1;
-            if samples > sample_max {
-                break;
-            }
+        counter.inc();
+        if counter.time_to_break(&args.num_records) {
+            break;
        }
     }
 
-    // (2) Derive the predict instrument results based on these detected
+    info!(
+        "Processed {} records.",
+        counter.get().to_formatted_string(&Locale::en)
+    );
+    metrics.total_records = counter.get();
+
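For reference, the parsing above targets Illumina-style query names of the form `instrument:run:flowcell:lane:tile:x:y`. `IlluminaReadName` is defined in `src/derive/instrument/reads.rs` (not part of this section), so the field names below are assumptions based on how the loop uses them:

```rust
// Hypothetical example; the exact fields of `IlluminaReadName` are assumed.
let name = "A00123:25:HXXXXXXXX:1:1101:1000:2000";
let read: IlluminaReadName = name.parse().unwrap();
assert_eq!(read.instrument_name, "A00123");
assert_eq!(read.flowcell, Some("HXXXXXXXX".to_string()));
```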
+    // (2) Validate the read group information.
+    let rgs_in_header_not_records = validate_read_group_info(&found_rgs, &header.parsed);
+    for rg_id in rgs_in_header_not_records {
+        let rg_ptr = Arc::new(rg_id);
+        instrument_names.insert(rg_ptr.clone(), HashSet::new());
+        flowcell_names.insert(rg_ptr, HashSet::new());
+    }
+
+    // (3) Derive the instrument results based on the detected
     // instrument names and flowcell names.
-    let result = compute::predict(instrument_names, flowcell_names);
+    let mut result = compute::predict(instrument_names, flowcell_names);
+    result.records = metrics;
 
-    // (3) Print the output to stdout as JSON (more support for different output
+    // (4) Print the output to stdout as JSON (more support for different output
     // types may be added in the future, but for now, only JSON).
     let output = serde_json::to_string_pretty(&result).unwrap();
-    print!("{}", output);
+    println!("{}", output);
 
     Ok(())
 }
diff --git a/src/derive/command/junction_annotation.rs b/src/derive/command/junction_annotation.rs
new file mode 100644
index 0000000..def6281
--- /dev/null
+++ b/src/derive/command/junction_annotation.rs
@@ -0,0 +1,139 @@
+//! Functionality relating to the `ngs derive junction-annotation` subcommand itself.
+
+use anyhow::Context;
+use clap::Args;
+use noodles::sam::record::MappingQuality;
+use num_format::{Locale, ToFormattedString};
+use std::collections::HashMap;
+use std::path::PathBuf;
+use tracing::{debug, info};
+
+use crate::derive::junction_annotation::compute;
+use crate::derive::junction_annotation::results::JunctionAnnotationResults;
+use crate::utils::display::RecordCounter;
+use crate::utils::formats;
+use crate::utils::formats::bam::ParsedBAMFile;
+use crate::utils::formats::utils::IndexCheck;
+
+/// Clap arguments for the `ngs derive junction-annotation` subcommand.
+#[derive(Args)]
+pub struct JunctionAnnotationArgs {
+    /// Source BAM.
+    #[arg(value_name = "BAM")]
+    src: PathBuf,
+
+    /// Features GFF file.
+    #[arg(short = 'f', long, required = true, value_name = "PATH")]
+    features_gff: PathBuf,
+
+    /// Name of the exon region feature for the gene model used.
+    #[arg(long, value_name = "STRING", default_value = "exon")]
+    exon_feature_name: String,
+
+    /// Minimum intron length to consider.
+    /// An intron is defined as an `N` CIGAR operation of any length.
+    #[arg(short = 'i', long, value_name = "USIZE", default_value = "50")]
+    min_intron_length: usize,
+
+    /// Minimum number of reads supporting a junction to be considered.
+    #[arg(short = 'r', long, value_name = "USIZE", default_value = "2")]
+    min_read_support: usize,
+
+    /// Minimum mapping quality for a record to be considered.
+    /// Default behavior is to ignore MAPQ values,
+    /// which allows reads with _missing_ MAPQs to be considered.
+    /// Specify any u8 value (lower than 255) to enable this filter.
+    /// Some aligners erroneously use 255 as the score for a uniquely mapped read;
+    /// however, 255 is reserved by the spec for a missing MAPQ value.
+    /// Therefore BAMs produced by aligners that use 255 erroneously
+    /// are not compatible with setting this option.
+    #[arg(short, long, value_name = "U8")]
+    min_mapq: Option<MappingQuality>,
+
+    /// Do not count supplementary alignments.
+    #[arg(long)]
+    no_supplementary: bool,
+
+    /// Do count secondary alignments.
+    #[arg(long)]
+    count_secondary: bool,
+
+    /// Do count duplicates.
+    #[arg(long)]
+    count_duplicates: bool,
+}
+
+/// Main function for the `ngs derive junction-annotation` subcommand.
+pub fn derive(args: JunctionAnnotationArgs) -> anyhow::Result<()> {
+    info!("Starting derive junction-annotation subcommand.");
+
+    let mut exons = compute::ExonSets {
+        starts: HashMap::new(),
+        ends: HashMap::new(),
+    };
+
+    // (1) Parse the GFF file and collect all exon features.
+    debug!("Reading all records in GFF.");
+    let mut gff = formats::gff::open(&args.features_gff)
+        .with_context(|| format!("opening GFF file: {}", args.features_gff.display()))?;
+
+    let mut exon_records = Vec::new();
+    for result in gff.records() {
+        let record = result.unwrap();
+        if record.ty() != args.exon_feature_name {
+            continue;
+        }
+        exon_records.push(record);
+    }
+    debug!("Read {} exon records.", exon_records.len());
+
+    debug!("Processing GFF exon features.");
+    for record in &exon_records {
+        let seq_name = record.reference_sequence_name();
+        let start = record.start();
+        let end = record.end().checked_add(1).unwrap(); // TODO: why +1? It works.
+
+        exons.starts.entry(seq_name).or_default().insert(start);
+        exons.ends.entry(seq_name).or_default().insert(end);
+    }
+
+    debug!("Done reading GFF.");
+
+    // (1.5) Initialize variables (including opening the BAM).
+    let mut counter = RecordCounter::default();
+    let mut results = JunctionAnnotationResults::default();
+    let params = compute::JunctionAnnotationParameters {
+        min_intron_length: args.min_intron_length,
+        min_read_support: args.min_read_support,
+        min_mapq: args.min_mapq,
+        no_supplementary: args.no_supplementary,
+        count_secondary: args.count_secondary,
+        count_duplicates: args.count_duplicates,
+    };
+
+    let ParsedBAMFile {
+        mut reader, header, ..
+    } = formats::bam::open_and_parse(args.src, IndexCheck::None)?;
+
+    // (2) Process each record in the BAM file.
+    for result in reader.records(&header.parsed) {
+        let record = result?;
+        compute::process(&record, &exons, &header.parsed, &params, &mut results)?;
+        counter.inc();
+    }
+
+    info!(
+        "Processed {} records.",
+        counter.get().to_formatted_string(&Locale::en)
+    );
+
+    // (3) Summarize found junctions.
+    compute::summarize(&mut results, &params);
+
+    // (4) Print the output to stdout as JSON (more support for different output
+    // types may be added in the future, but for now, only JSON).
+    let output = serde_json::to_string_pretty(&results).unwrap();
+    println!("{}", output);
+
+    anyhow::Ok(())
+}
diff --git a/src/derive/command/readlen.rs b/src/derive/command/readlen.rs
new file mode 100644
index 0000000..1cf46d8
--- /dev/null
+++ b/src/derive/command/readlen.rs
@@ -0,0 +1,96 @@
+//! Functionality relating to the `ngs derive readlen` subcommand itself.
+
+use anyhow::Context;
+use clap::Args;
+use num_format::{Locale, ToFormattedString};
+use std::collections::HashMap;
+use std::collections::HashSet;
+use std::path::PathBuf;
+use std::sync::Arc;
+use tracing::info;
+
+use crate::derive::readlen::compute;
+use crate::utils::args::arg_in_range as cutoff_in_range;
+use crate::utils::args::NumberOfRecords;
+use crate::utils::display::RecordCounter;
+use crate::utils::formats::bam::ParsedBAMFile;
+use crate::utils::formats::utils::IndexCheck;
+use crate::utils::read_groups::{get_read_group, validate_read_group_info, ReadGroupPtr};
+
+/// Clap arguments for the `ngs derive readlen` subcommand.
+#[derive(Args)]
+pub struct DeriveReadlenArgs {
+    /// Source BAM.
+    #[arg(value_name = "BAM")]
+    src: PathBuf,
+
+    /// Examine the first `n` records in the file.
+    #[arg(
+        short,
+        long,
+        default_value = "10000000",
+        value_name = "'all' or a positive, non-zero integer"
+    )]
+    num_records: NumberOfRecords,
+
+    /// Majority vote cutoff value as a fraction between [0.0, 1.0].
+    #[arg(short, long, value_name = "F64", default_value = "0.7")]
+    majority_vote_cutoff: f64,
+}
+
+/// Main function for the `ngs derive readlen` subcommand.
+pub fn derive(args: DeriveReadlenArgs) -> anyhow::Result<()> {
+    // (0) Parse arguments needed for subcommand.
+    let majority_vote_cutoff = cutoff_in_range(args.majority_vote_cutoff, 0.0..=1.0)
+        .with_context(|| "Majority vote cutoff is not within acceptable range")?;
+
+    let mut read_lengths: HashMap<ReadGroupPtr, HashMap<usize, usize>> = HashMap::new();
+    let mut found_rgs = HashSet::new();
+
+    info!("Starting derive readlen subcommand.");
+
+    let ParsedBAMFile {
+        mut reader, header, ..
+    } = crate::utils::formats::bam::open_and_parse(args.src, IndexCheck::None)?;
+
+    // (1) Collect read lengths from reads within the
+    // file. Support for sampling only a portion of the reads is provided.
+    let mut counter = RecordCounter::default();
+    for result in reader.records(&header.parsed) {
+        let record = result?;
+        let read_group = get_read_group(&record, Some(&mut found_rgs));
+        let len = record.sequence().len();
+
+        *read_lengths
+            .entry(read_group)
+            .or_default()
+            .entry(len)
+            .or_default() += 1;
+
+        counter.inc();
+        if counter.time_to_break(&args.num_records) {
+            break;
+        }
+    }
+
+    info!(
+        "Processed {} records.",
+        counter.get().to_formatted_string(&Locale::en)
+    );
+
+    // (2) Validate the read group information.
+    let rgs_in_header_not_records = validate_read_group_info(&found_rgs, &header.parsed);
+    for rg_id in rgs_in_header_not_records {
+        read_lengths.insert(Arc::new(rg_id), HashMap::new());
+    }
+
+    // (3) Derive the consensus read length based on the read lengths gathered.
+    let result = compute::predict(read_lengths, majority_vote_cutoff);
+
+    // (4) Print the output to stdout as JSON (more support for different output
+    // types may be added in the future, but for now, only JSON).
+    let output = serde_json::to_string_pretty(&result).unwrap();
+    println!("{}", output);
+
+    anyhow::Ok(())
+}
diff --git a/src/derive/command/strandedness.rs b/src/derive/command/strandedness.rs
new file mode 100644
index 0000000..5ce71d7
--- /dev/null
+++ b/src/derive/command/strandedness.rs
@@ -0,0 +1,290 @@
+//! Functionality relating to the `ngs derive strandedness` subcommand itself.
+
+use anyhow::{bail, Context};
+use clap::Args;
+use noodles::sam::record::MappingQuality;
+use noodles::{bam, gff};
+use rust_lapper::{Interval, Lapper};
+use std::collections::{HashMap, HashSet};
+use std::fs::File;
+use std::path::PathBuf;
+use tracing::{debug, info};
+
+use crate::derive::strandedness::compute::ParsedBAMFile;
+use crate::derive::strandedness::{compute, results};
+use crate::utils::formats;
+use crate::utils::read_groups::validate_read_group_info;
+
+/// Clap arguments for the `ngs derive strandedness` subcommand.
+#[derive(Args)]
+pub struct DeriveStrandednessArgs {
+    /// Source BAM.
+    #[arg(value_name = "BAM")]
+    src: PathBuf,
+
+    /// Features GFF file.
+    #[arg(short = 'f', long, required = true, value_name = "PATH")]
+    features_gff: PathBuf,
+
+    /// When inconclusive, the test will repeat until this many tries have been reached.
+    /// Evidence from previous attempts is saved and reused,
+    /// leading to a larger sample size with multiple attempts.
+    #[arg(long, value_name = "USIZE", default_value = "3")]
+    max_tries: usize,
+
+    /// Filter any genes that don't have at least `m` reads.
+    #[arg(short = 'm', long, value_name = "USIZE", default_value = "10")]
+    min_reads_per_gene: usize,
+
+    /// How many genes to use as evidence in strandedness classification per try.
+    /// This does not count genes which fail filtering
+    /// due to `--min-reads-per-gene` or are discarded
+    /// due to problematic strand information in the GFF.
+    /// Problematic strand information is caused by contradictions between
+    /// gene entries and overlapping exon entries.
+    #[arg(short = 'n', long, value_name = "USIZE", default_value = "1000")]
+    num_genes: usize,
+
+    /// Minimum mapping quality for a record to be considered.
+    /// Default behavior is to ignore MAPQ values,
+    /// which allows reads with _missing_ MAPQs to be considered.
+    /// Specify any u8 value (lower than 255) to enable this filter.
+    /// Some aligners erroneously use 255 as the score for a uniquely mapped read;
+    /// however, 255 is reserved by the spec for a missing MAPQ value.
+    /// Therefore BAMs produced by aligners that use 255 erroneously
+    /// are not compatible with setting this option.
+    #[arg(long, value_name = "U8")]
+    min_mapq: Option<MappingQuality>,
+
+    /// Consider all genes, not just protein coding genes.
+    #[arg(long)]
+    all_genes: bool,
+
+    /// Name of the gene region feature for the gene model used.
+    #[arg(long, value_name = "STRING", default_value = "gene")]
+    gene_feature_name: String,
+
+    /// Name of the exon region feature for the gene model used.
+    #[arg(long, value_name = "STRING", default_value = "exon")]
+    exon_feature_name: String,
+
+    /// Do not count supplementary alignments.
+    #[arg(long)]
+    no_supplementary: bool,
+
+    /// Do count secondary alignments.
+    #[arg(long)]
+    count_secondary: bool,
+
+    /// Do count duplicates.
+    #[arg(long)]
+    count_duplicates: bool,
+
+    /// Do count QC failed reads.
+    #[arg(long)]
+    count_qc_failed: bool,
+
+    /// At most, evaluate this many genes per try.
+    /// Default is 10 * `--num-genes`.
+    #[arg(long, value_name = "USIZE")]
+    max_genes_per_try: Option<usize>,
+}
+
+/// Main function for the `ngs derive strandedness` subcommand.
+pub fn derive(args: DeriveStrandednessArgs) -> anyhow::Result<()> {
+    info!("Starting derive strandedness subcommand.");
+
+    // (1) Parse the GFF file and collect all gene and exon features.
+    debug!("Reading all records in GFF.");
+    let mut gff = formats::gff::open(&args.features_gff)
+        .with_context(|| format!("opening GFF file: {}", args.features_gff.display()))?;
+
+    let mut gene_records = Vec::new();
+    let mut exon_records = Vec::new();
+    let mut gene_metrics = results::GeneRecordMetrics::default();
+    let mut exon_metrics = results::ExonRecordMetrics::default();
+    for result in gff.records() {
+        let record = result.unwrap();
+        if record.ty() == args.gene_feature_name {
+            gene_metrics.total += 1;
+
+            // If --all-genes is set, don't check the gene type or biotype.
+            // Otherwise, check the gene type or biotype and keep the record if it's protein coding.
+            // If the record does not have a gene type or biotype, discard it.
+            let mut keep_record = false;
+            if !args.all_genes {
+                let mut gene_type_value = None;
+                for entry in record.attributes().as_ref() {
+                    gene_type_value = match entry.key() {
+                        "gene_type" => Some(entry.value()),    // Gencode
+                        "gene_biotype" => Some(entry.value()), // ENSEMBL
+                        "biotype" => Some(entry.value()),      // also ENSEMBL
+                        _ => gene_type_value,
+                    };
+                }
+                if let Some(gene_type_value) = gene_type_value {
+                    if gene_type_value.to_lowercase().contains("protein") {
+                        keep_record = true;
+                        gene_metrics.protein_coding += 1;
+                    }
+                }
+            }
+            if !keep_record {
+                continue;
+            }
+
+            // Make sure the gene record has a valid strand.
+            let gene_strand = record.strand();
+            if gene_strand != gff::record::Strand::Forward
+                && gene_strand != gff::record::Strand::Reverse
+            {
+                gene_metrics.bad_strand += 1;
+                continue;
+            }
+
+            gene_records.push(record);
+        } else if record.ty() == args.exon_feature_name {
+            exon_metrics.total += 1;
+            exon_records.push(record);
+        }
+    }
+    if gene_records.is_empty() {
+        bail!("No gene records matched criteria. Check your GFF file and `--gene-feature-name` and `--all-genes` options.");
+    }
+    if exon_records.is_empty() {
+        bail!("No exon records matched criteria. Check your GFF file and `--exon-feature-name` option.");
+    }
+    debug!(
+        "Found {} gene records and {} exon records.",
+        gene_records.len(),
+        exon_records.len()
+    );
+
+    // (2) Parse exon features into proper data structure.
+    debug!("Tabulating GFF exon features.");
+
+    let mut exon_intervals: HashMap<&str, Vec<Interval<usize, compute::Strand>>> = HashMap::new();
+    for record in &exon_records {
+        let seq_name = record.reference_sequence_name();
+        let start: usize = record.start().into();
+        let stop: usize = record.end().into();
+        let strand = record.strand();
+
+        if strand != gff::record::Strand::Forward && strand != gff::record::Strand::Reverse {
+            exon_metrics.bad_strand += 1;
+            continue;
+        }
+        let strand = compute::Strand::try_from(strand).unwrap(); // above check guarantees safety
+
+        exon_intervals.entry(seq_name).or_default().push(Interval {
+            start,
+            stop,
+            val: strand,
+        });
+    }
+
+    if exon_metrics.bad_strand == exon_metrics.total {
+        bail!("All exons were discarded due to bad strand information. Check your GFF file.");
+    }
+    debug!(
+        "{} exons were discarded due to bad strand information.",
+        exon_metrics.bad_strand
+    );
+
+    let mut exons: HashMap<&str, Lapper<usize, compute::Strand>> = HashMap::new();
+    for (seq_name, intervals) in exon_intervals {
+        exons.insert(seq_name, Lapper::new(intervals));
+    }
+
+    debug!("Done reading GFF.");
+
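The `Lapper` built per sequence name above is rust-lapper's interval structure, which gives fast overlap queries when the strandedness computation later checks which exon strands a read's alignment overlaps. An illustrative query (coordinates made up):

```rust
// Find the strands of all exons overlapping the half-open interval
// [1_000, 1_051) on chr1. `find` yields every stored Interval that
// overlaps the query range.
if let Some(lapper) = exons.get("chr1") {
    for exon in lapper.find(1_000, 1_051) {
        // `exon.val` is the `compute::Strand` stored when the Lapper was built.
        let _strand = &exon.val;
    }
}
```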
+    // (3) Initialize variables (including opening the BAM).
+    let mut reader = File::open(&args.src)
+        .map(bam::Reader::new)
+        .with_context(|| format!("opening BAM file: {}", args.src.display()))?;
+    let header = reader.read_header()?.parse()?;
+    let index = bam::bai::read(args.src.with_extension("bam.bai")).with_context(|| {
+        format!(
+            "reading BAM index: {}",
+            args.src.with_extension("bam.bai").display()
+        )
+    })?;
+
+    let mut parsed_bam = ParsedBAMFile {
+        reader,
+        header,
+        index,
+    };
+
+    let max_genes_per_try = args.max_genes_per_try.unwrap_or(args.num_genes * 10);
+
+    let params = compute::StrandednessParams {
+        num_genes: args.num_genes,
+        max_genes_per_try,
+        min_reads_per_gene: args.min_reads_per_gene,
+        min_mapq: args.min_mapq,
+        count_qc_failed: args.count_qc_failed,
+        no_supplementary: args.no_supplementary,
+        count_secondary: args.count_secondary,
+        count_duplicates: args.count_duplicates,
+    };
+
+    let mut all_counts = compute::AllReadGroupsCounts {
+        counts: HashMap::new(),
+        found_rgs: HashSet::new(),
+    };
+    let mut metrics = results::RecordTracker {
+        genes: gene_metrics,
+        exons: exon_metrics,
+        reads: results::ReadRecordMetrics::default(),
+    };
+    let mut result: Option<results::DerivedStrandednessResult> = None;
+
+    // (4) Run the strandedness test.
+    for try_num in 1..=args.max_tries {
+        info!("Starting try {} of {}", try_num, args.max_tries);
+
+        let attempt = compute::predict(
+            &mut parsed_bam,
+            &mut gene_records,
+            &exons,
+            &mut all_counts,
+            &params,
+            &mut metrics,
+        )?;
+        let success = attempt.succeeded;
+        result = Some(attempt);
+        if success {
+            info!("Strandedness test succeeded.");
+            break;
+        } else {
+            info!("Strandedness test inconclusive.");
+        }
+    }
+    let mut result = result.unwrap();
+
+    if !result.succeeded {
+        info!(
+            "Strandedness test still failed after {} tries.",
+            args.max_tries
+        );
+    }
+
+    let rgs_in_header_not_found =
+        validate_read_group_info(&all_counts.found_rgs, &parsed_bam.header);
+    let mut empty_rg_results = Vec::new();
+    for rg in rgs_in_header_not_found {
+        empty_rg_results.push(compute::predict_strandedness(
+            &rg,
+            &compute::Counts::default(),
+        ));
+    }
+    result.read_groups.extend(empty_rg_results);
+
+    // (5) Print the output to stdout as JSON (more support for different output
+    // types may be added in the future, but for now, only JSON).
+    let output = serde_json::to_string_pretty(&result).unwrap();
+    println!("{}", output);
+
+    anyhow::Ok(())
+}
diff --git a/src/derive/encoding.rs b/src/derive/encoding.rs
new file mode 100644
index 0000000..a20dc0d
--- /dev/null
+++ b/src/derive/encoding.rs
@@ -0,0 +1,3 @@
+//! Supporting functionality for the `ngs derive encoding` subcommand.
+
+pub mod compute;
diff --git a/src/derive/encoding/compute.rs b/src/derive/encoding/compute.rs
new file mode 100644
index 0000000..8063fb7
--- /dev/null
+++ b/src/derive/encoding/compute.rs
@@ -0,0 +1,326 @@
+//! Module holding the logic for computing the quality score encoding.
+
+use anyhow::bail;
+use serde::Serialize;
+use std::collections::HashSet;
+
+const MAX_VALID_PHRED_SCORE: u8 = 93;
+const SANGER_MIN: u8 = 0;
+const ILLUMINA_1_0_MIN: u8 = 26;
+const ILLUMINA_1_3_MIN: u8 = 31;
+
+/// Struct holding the final results for an `ngs derive encoding` subcommand
+/// call.
+#[derive(Debug, Serialize)]
+pub struct DerivedEncodingResult {
+    /// Whether or not the `ngs derive encoding` subcommand succeeded.
+    pub succeeded: bool,
+
+    /// The detected quality score encoding, if derivable.
+    pub encoding: Option<String>,
+
+    /// The minimum quality score observed.
+    pub observed_min: u8,
+
+    /// The maximum quality score observed.
+    pub observed_max: u8,
+}
+
+impl DerivedEncodingResult {
+    /// Creates a new [`DerivedEncodingResult`].
+    pub fn new(
+        succeeded: bool,
+        encoding: Option<String>,
+        observed_min: u8,
+        observed_max: u8,
+    ) -> Self {
+        DerivedEncodingResult {
+            succeeded,
+            encoding,
+            observed_min,
+            observed_max,
+        }
+    }
+}
+
+/// Main method to evaluate the observed quality scores and
+/// return a result for the derived encoding. This may fail, and the
+/// resulting [`DerivedEncodingResult`] should be evaluated accordingly.
+pub fn predict(score_set: HashSet<u8>) -> anyhow::Result<DerivedEncodingResult> {
+    if score_set.is_empty() {
+        bail!("No quality scores were detected in the file.");
+    }
+
+    let observed_min = *score_set.iter().min().unwrap();
+    let observed_max = *score_set.iter().max().unwrap();
+
+    let mut result = DerivedEncodingResult::new(false, None, observed_min, observed_max);
+
+    if observed_max > MAX_VALID_PHRED_SCORE {
+        return anyhow::Ok(result);
+    }
+    match observed_min {
+        ILLUMINA_1_3_MIN..=MAX_VALID_PHRED_SCORE => {
+            result.succeeded = true;
+            result.encoding = Some("Illumina 1.3".to_string());
+        }
+        ILLUMINA_1_0_MIN..=MAX_VALID_PHRED_SCORE => {
+            result.succeeded = true;
+            result.encoding = Some("Illumina 1.0".to_string());
+        }
+        SANGER_MIN..=MAX_VALID_PHRED_SCORE => {
+            result.succeeded = true;
+            result.encoding = Some("Sanger/Illumina 1.8".to_string());
+        }
+        _ => unreachable!(),
+    }
+
+    anyhow::Ok(result)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_predict_illumina_1_3() {
+        let score_set: HashSet<u8> = (40..=93).collect();
+
+        let result = predict(score_set).unwrap();
+        assert!(result.succeeded);
+        assert_eq!(result.encoding, Some("Illumina 1.3".to_string()));
+        assert_eq!(result.observed_min, 40);
+        assert_eq!(result.observed_max, 93);
+    }
+
+    #[test]
+    fn test_predict_illumina_1_0() {
+        let score_set: HashSet<u8> = (26..=93).collect();
+
+        let result = predict(score_set).unwrap();
+        assert!(result.succeeded);
+        assert_eq!(result.encoding, Some("Illumina 1.0".to_string()));
+        assert_eq!(result.observed_min, 26);
+        assert_eq!(result.observed_max, 93);
+    }
+
+    #[test]
+    fn test_predict_sanger() {
+        let score_set: HashSet<u8> = (0..=68).collect();
+
+        let result = predict(score_set).unwrap();
+        assert!(result.succeeded);
+        assert_eq!(result.encoding, Some("Sanger/Illumina 1.8".to_string()));
+        assert_eq!(result.observed_min, 0);
+        assert_eq!(result.observed_max, 68);
+    }
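The cutoffs used by `predict` read naturally against the historical FASTQ ASCII offsets, re-based to BAM's raw PHRED storage: Sanger/Illumina 1.8 uses offset 33 (so a raw minimum of 0), Solexa/Illumina 1.0 bottoms out at ASCII 59 (59 - 33 = 26), and Illumina 1.3 uses offset 64 (64 - 33 = 31). A quick sanity check:

```rust
// BAM stores PHRED scores directly (no ASCII offset), so the constants in
// this module are raw scores. Re-derive them from the FASTQ offsets:
const SANGER_OFFSET: u8 = 33;
assert_eq!(59 - SANGER_OFFSET, 26); // ILLUMINA_1_0_MIN
assert_eq!(64 - SANGER_OFFSET, 31); // ILLUMINA_1_3_MIN
```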
+
+    #[test]
+    fn test_predict_fail() {
+        let score_set: HashSet<u8> = HashSet::new();
+        let result = predict(score_set);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_predict_too_high_max_score() {
+        let mut score_set: HashSet<u8> = HashSet::new();
+        score_set.insert(94);
+        let result = predict(score_set).unwrap();
+        assert!(!result.succeeded);
+        assert_eq!(result.encoding, None);
+        assert_eq!(result.observed_min, 94);
+        assert_eq!(result.observed_max, 94);
+    }
+}
diff --git a/src/derive/endedness.rs b/src/derive/endedness.rs
new file mode 100644
index 0000000..6f8d5cc
--- /dev/null
+++ b/src/derive/endedness.rs
@@ -0,0 +1,4 @@
+//! Supporting functionality for the `ngs derive endedness` subcommand.
+
+pub mod compute;
+pub mod results;
diff --git a/src/derive/endedness/compute.rs b/src/derive/endedness/compute.rs
new file mode 100644
index 0000000..9482ac1
--- /dev/null
+++ b/src/derive/endedness/compute.rs
@@ -0,0 +1,614 @@
+//! Module holding the logic for computing the endedness of a BAM.
+
+use std::collections::{HashMap, HashSet};
+use std::ops::{Add, AddAssign};
+use std::sync::Arc;
+use tracing::warn;
+
+use crate::derive::endedness::results;
+use crate::utils::read_groups::ReadGroupPtr;
+
+/// Struct holding the ordering flags for a single read group.
+#[derive(Debug, Clone, Default)]
+pub struct OrderingFlagsCounts {
+    /// The number of reads without 0x1 set.
+    pub unsegmented: usize,
+
+    /// The number of reads with the first in template flag set.
+    pub first: usize,
+
+    /// The number of reads with the last in template flag set.
+    pub last: usize,
+
+    /// The number of reads with both the first and last in template flags set.
+    pub both: usize,
+
+    /// The number of reads with neither the first nor last in template flags set.
+    pub neither: usize,
+}
+
+impl OrderingFlagsCounts {
+    /// Creates a new [`OrderingFlagsCounts`].
+    pub fn new() -> Self {
+        OrderingFlagsCounts {
+            unsegmented: 0,
+            first: 0,
+            last: 0,
+            both: 0,
+            neither: 0,
+        }
+    }
+}
+
+impl Add for OrderingFlagsCounts {
+    type Output = Self;
+
+    fn add(self, other: Self) -> Self {
+        OrderingFlagsCounts {
+            unsegmented: self.unsegmented + other.unsegmented,
+            first: self.first + other.first,
+            last: self.last + other.last,
+            both: self.both + other.both,
+            neither: self.neither + other.neither,
+        }
+    }
+}
+
+impl AddAssign for OrderingFlagsCounts {
+    fn add_assign(&mut self, other: Self) {
+        self.unsegmented += other.unsegmented;
+        self.first += other.first;
+        self.last += other.last;
+        self.both += other.both;
+        self.neither += other.neither;
+    }
+}
+
+/// Calculate the reads per template overall and for each read group.
+fn calculate_reads_per_template(
+    read_names: HashMap<String, Vec<ReadGroupPtr>>,
+    reads_per_template: &mut HashMap<ReadGroupPtr, f64>,
+) -> f64 {
+    let mut total_reads: usize = 0;
+    let mut total_templates: usize = 0;
+    let mut read_group_reads: HashMap<ReadGroupPtr, usize> = HashMap::new();
+    let mut read_group_templates: HashMap<ReadGroupPtr, usize> = HashMap::new();
+
+    let mut warning_count: usize = 0;
+
+    for (read_name, read_groups) in read_names.iter() {
+        let num_reads = read_groups.len();
+        total_reads += num_reads;
+        total_templates += 1;
+
+        let read_group_set: HashSet<ReadGroupPtr> = read_groups.iter().cloned().collect();
+
+        if read_group_set.len() == 1 {
+            // All found read groups assigned to this QNAME are the same.
+            // We assume this means all the reads came from the same template.
+            let read_group = Arc::clone(read_group_set.iter().next().unwrap());
+
+            *read_group_reads.entry(Arc::clone(&read_group)).or_default() += num_reads;
+            *read_group_templates.entry(read_group).or_default() += 1;
+        } else {
+            // The QNAME is in multiple read groups.
+            // We assume this means the reads came from multiple templates.
+            // More specifically, we assume that exactly one template will originate from each read group.
+            warning_count += 1;
+            match warning_count {
+                1..=100 => {
+                    warn!(
+                        "QNAME: '{}' is in multiple read groups: {:?}",
+                        read_name, read_group_set
+                    );
+                }
+                101 => warn!(
+                    "Too many warnings about QNAMEs in multiple read groups. Stopping warnings."
+                ),
+                _ => (),
+            }
+
+            for read_group in read_groups {
+                *read_group_reads.entry(Arc::clone(read_group)).or_default() += 1;
+            }
+            for read_group in read_group_set {
+                *read_group_templates.entry(read_group).or_default() += 1;
+            }
+        }
+    }
+
+    if warning_count > 100 {
+        warn!(
+            "{} QNAMEs were found in multiple read groups.",
+            warning_count
+        );
+    }
+
+    let overall_rpt = total_reads as f64 / total_templates as f64;
+
+    for (read_group, num_reads) in read_group_reads.iter() {
+        let num_templates = read_group_templates.get(read_group).unwrap();
+        let rpt = *num_reads as f64 / *num_templates as f64;
+        reads_per_template.insert(Arc::clone(read_group), rpt);
+    }
+
+    overall_rpt
+}
+
+fn predict_endedness(
+    read_group_name: String,
+    rg_ordering_flags: &OrderingFlagsCounts,
+    paired_deviance: f64,
+    reads_per_template: Option<f64>,
+    round_rpt: bool,
+) -> results::ReadGroupDerivedEndednessResult {
+    let unsegmented = rg_ordering_flags.unsegmented;
+    let first = rg_ordering_flags.first;
+    let last = rg_ordering_flags.last;
+    let both = rg_ordering_flags.both;
+    let neither = rg_ordering_flags.neither;
+
+    // all zeroes (Perform this check before creating the result struct
+    // so that we don't have to clone the read group name)
+    if unsegmented == 0 && first == 0 && last == 0 && both == 0 && neither == 0 {
+        warn!(
+            "No reads were detected in this read group: {}",
+            read_group_name
+        );
+        return results::ReadGroupDerivedEndednessResult::new(
+            read_group_name,
+            false,
+            None,
+            rg_ordering_flags.clone(),
+            reads_per_template,
+        );
+    }
+
+    let mut result = results::ReadGroupDerivedEndednessResult::new(
+        read_group_name,
+        false,
+        None,
+        rg_ordering_flags.clone(),
+        reads_per_template,
+    );
+
+    // only unsegmented present
+    if unsegmented > 0 && first == 0 && last == 0 && both == 0 && neither == 0 {
+        match reads_per_template {
+            Some(rpt) => {
+                if rpt == 1.0 || (round_rpt && rpt.round() as usize == 1) {
+                    result.succeeded = true;
+                    result.endedness = Some(String::from("Single-End"));
+                }
+            }
+            None => {
+                result.succeeded = true;
+                result.endedness = Some(String::from("Single-End"));
+            }
+        }
+        return result;
+    }
+    // unsegmented reads are present, and so are other types of reads.
+    if unsegmented > 0 {
+        return result;
+    }
+    // now unsegmented is guaranteed to be 0
+
+    // only first present
+    if first > 0 && last == 0 && both == 0 && neither == 0 {
+        return result;
+    }
+    // only last present
+    if first == 0 && last > 0 && both == 0 && neither == 0 {
+        return result;
+    }
+    // only both present
+    if first == 0 && last == 0 && both > 0 && neither == 0 {
+        // Prior logic (before addition of unsegmented checks) left as comment for posterity
+        // match reads_per_template {
+        //     Some(rpt) => {
+        //         if rpt == 1.0 || (round_rpt && rpt.round() as usize == 1) {
+        //             result.succeeded = true;
+        //             result.endedness = String::from("Single-End");
+        //         }
+        //     }
+        //     None => {
+        //         result.succeeded = true;
+        //         result.endedness = String::from("Single-End");
+        //     }
+        // }
+        return result;
+    }
+    // only neither present
+    if first == 0 && last == 0 && both == 0 && neither > 0 {
+        return result;
+    }
+    // first/last mixed with both/neither
+    if (first > 0 || last > 0) && (both > 0 || neither > 0) {
+        return result;
+    }
+    // any mix of both/neither, regardless of first/last
+    if both > 0 && neither > 0 {
+        return result;
+    }
+
+    // both and neither are now guaranteed to be 0
+    // We only need to check first and last
+
+    let first_frac = first as f64 / (first + last) as f64;
+    let lower_limit = 0.5 - paired_deviance;
+    let upper_limit = 0.5 + paired_deviance;
+    if (first == last) || (lower_limit <= first_frac && first_frac <= upper_limit) {
+        match reads_per_template {
+            Some(rpt) => {
+                if rpt == 2.0 || (round_rpt && rpt.round() as usize == 2) {
+                    result.succeeded = true;
+                    result.endedness = Some(String::from("Paired-End"));
+                }
+            }
+            None => {
+                result.succeeded = true;
+                result.endedness = Some(String::from("Paired-End"));
+            }
+        }
+    }
+    result
+}
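A worked example of the deviance window at the end of `predict_endedness` (the numbers are illustrative): with `paired_deviance = 0.01`, a 99/101 split between first and last reads still calls Paired-End, because 99 / 200 = 0.495 lies within [0.49, 0.51].

```rust
let (first, last, paired_deviance) = (99_usize, 101_usize, 0.01_f64);
let first_frac = first as f64 / (first + last) as f64;
let lower_limit = 0.5 - paired_deviance;
let upper_limit = 0.5 + paired_deviance;
assert!(first == last || (lower_limit <= first_frac && first_frac <= upper_limit));
```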
+/// Main method to evaluate the collected ordering flags and
+/// return a result for the endedness of the file. This may fail, and the
+/// resulting [`results::DerivedEndednessResult`] should be evaluated accordingly.
+pub fn predict(
+    ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts>,
+    read_names: Option<HashMap<String, Vec<ReadGroupPtr>>>,
+    paired_deviance: f64,
+    round_rpt: bool,
+) -> results::DerivedEndednessResult {
+    let mut rg_rpts: HashMap<ReadGroupPtr, f64> = HashMap::new();
+    let mut overall_rpt: Option<f64> = None;
+    if let Some(read_names) = read_names {
+        overall_rpt = Some(calculate_reads_per_template(read_names, &mut rg_rpts));
+    }
+
+    let mut overall_flags = OrderingFlagsCounts::new();
+    let mut rg_results = Vec::new();
+
+    for (read_group, rg_ordering_flags) in ordering_flags.iter() {
+        overall_flags += rg_ordering_flags.clone();
+
+        let result = predict_endedness(
+            read_group.to_string(),
+            rg_ordering_flags,
+            paired_deviance,
+            rg_rpts.get(read_group).copied(),
+            round_rpt,
+        );
+        rg_results.push(result);
+    }
+
+    let overall_result = predict_endedness(
+        "overall".to_string(),
+        &overall_flags,
+        paired_deviance,
+        overall_rpt,
+        round_rpt,
+    );
+
+    results::DerivedEndednessResult::new(
+        overall_result.succeeded,
+        overall_result.endedness,
+        overall_flags,
+        overall_rpt,
+        rg_results,
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_predict_endedness_from_first_and_last() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(
+            Arc::new("overall".to_string()),
+            OrderingFlagsCounts {
+                unsegmented: 0,
+                first: 1,
+                last: 1,
+                both: 0,
+                neither: 0,
+            },
+        );
+        let result = predict_endedness(
+            "overall".to_string(),
+            ordering_flags
+                .get(&Arc::new("overall".to_string()))
+                .unwrap(),
+            0.0,
+            None,
+            false,
+        );
+        assert!(result.succeeded);
+        assert_eq!(result.endedness, Some("Paired-End".to_string()));
+        assert_eq!(result.first, 1);
+        assert_eq!(result.last, 1);
+        assert_eq!(result.both, 0);
+        assert_eq!(result.neither, 0);
+        assert_eq!(result.rpt, None);
+    }
+
+    #[test]
+    fn test_predict_endedness_from_unsegmented() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(
+            Arc::new("overall".to_string()),
+            OrderingFlagsCounts {
+                unsegmented: 1,
+                first: 0,
+                last: 0,
+                both: 0,
+                neither: 0,
+            },
+        );
+        let result = predict_endedness(
+            "overall".to_string(),
+            ordering_flags
+                .get(&Arc::new("overall".to_string()))
+                .unwrap(),
+            0.0,
+            None,
+            false,
+        );
+        assert!(result.succeeded);
+        assert_eq!(result.endedness, Some("Single-End".to_string()));
+        assert_eq!(result.first, 0);
+        assert_eq!(result.last, 0);
+        assert_eq!(result.both, 0);
+        assert_eq!(result.neither, 0);
+        assert_eq!(result.rpt, None);
+    }
+
+    #[test]
+    fn test_predict_endedness_from_all_zero_counts() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(Arc::new(String::from("rg1")), OrderingFlagsCounts::new());
+        let result = predict_endedness(
+            String::from("rg1"),
+            ordering_flags.get(&Arc::new(String::from("rg1"))).unwrap(),
+            0.0,
+            None,
+            false,
+        );
+        assert!(!result.succeeded);
+        assert_eq!(result.endedness, None);
+        assert_eq!(result.first, 0);
+        assert_eq!(result.last, 0);
+        assert_eq!(result.both, 0);
+        assert_eq!(result.neither, 0);
+        assert_eq!(result.rpt, None);
+    }
+
+    #[test]
+    fn test_predict_from_only_first() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(
+            Arc::new("overall".to_string()),
+            OrderingFlagsCounts {
+                unsegmented: 0,
+                first: 1,
+                last: 0,
+                both: 0,
+                neither: 0,
+            },
+        );
+        let result = predict(ordering_flags, None, 0.0, false);
+        assert!(!result.succeeded);
+        assert_eq!(result.endedness, None);
+        assert_eq!(result.first, 1);
+        assert_eq!(result.last, 0);
+        assert_eq!(result.both, 0);
+        assert_eq!(result.neither, 0);
+        assert_eq!(result.rpt, None);
+        assert_eq!(result.read_groups.len(), 1);
+    }
+
+    #[test]
+    fn test_predict_from_only_last() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(
+            Arc::new("overall".to_string()),
+            OrderingFlagsCounts {
+                unsegmented: 0,
+                first: 0,
+                last: 1,
+                both: 0,
+                neither: 0,
+            },
+        );
+        let result = predict(ordering_flags, None, 0.0, false);
+        assert!(!result.succeeded);
+        assert_eq!(result.endedness, None);
+        assert_eq!(result.first, 0);
+        assert_eq!(result.last, 1);
+        assert_eq!(result.both, 0);
+        assert_eq!(result.neither, 0);
+        assert_eq!(result.rpt, None);
+        assert_eq!(result.read_groups.len(), 1);
+    }
+
+    #[test]
+    fn test_predict_from_only_both() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(
+            Arc::new("overall".to_string()),
+            OrderingFlagsCounts {
+                unsegmented: 0,
+                first: 0,
+                last: 0,
+                both: 1,
+                neither: 0,
+            },
+        );
+        let result = predict(ordering_flags, None, 0.0, false);
+        assert!(!result.succeeded);
+        assert_eq!(result.endedness, None);
+        assert_eq!(result.first, 0);
+        assert_eq!(result.last, 0);
+        assert_eq!(result.both, 1);
+        assert_eq!(result.neither, 0);
+        assert_eq!(result.rpt, None);
+        assert_eq!(result.read_groups.len(), 1);
+    }
+
+    #[test]
+    fn test_predict_from_only_neither() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(
+            Arc::new("overall".to_string()),
+            OrderingFlagsCounts {
+                unsegmented: 0,
+                first: 0,
+                last: 0,
+                both: 0,
+                neither: 1,
+            },
+        );
+        let result = predict(ordering_flags, None, 0.0, false);
+        assert!(!result.succeeded);
+        assert_eq!(result.endedness, None);
+        assert_eq!(result.first, 0);
+        assert_eq!(result.last, 0);
+        assert_eq!(result.both, 0);
+        assert_eq!(result.neither, 1);
+        assert_eq!(result.rpt, None);
+        assert_eq!(result.read_groups.len(), 1);
+    }
+
+    #[test]
+    fn test_predict_from_first_and_last() {
+        let mut ordering_flags: HashMap<ReadGroupPtr, OrderingFlagsCounts> = HashMap::new();
+        ordering_flags.insert(
+            Arc::new("overall".to_string()),
+            OrderingFlagsCounts {
+                unsegmented: 0,
+                first: 1,
+                last: 1,
+                both: 0,
+                neither: 0,
+            },
+        );
+        let result = predict(ordering_flags, None, 0.0, false);
+        assert!(result.succeeded);
+        assert_eq!(result.endedness, Some("Paired-End".to_string()));
+        assert_eq!(result.first, 1);
+        assert_eq!(result.last, 1);
+        assert_eq!(result.both, 0);
+        assert_eq!(result.neither, 0);
+        assert_eq!(result.rpt, None);
+        assert_eq!(result.read_groups.len(), 1);
+    }
+
+    #[test]
+    fn test_calculate_reads_per_template() {
+        let mut read_names: HashMap<String, Vec<ReadGroupPtr>> = HashMap::new();
+        let rg_paired = Arc::new("rg_paired".to_string());
+        let rg_single = Arc::new("rg_single".to_string());
+        read_names.insert(
+            "read1".to_string(),
+            vec![Arc::clone(&rg_paired), Arc::clone(&rg_paired)],
+        );
+        read_names.insert(
+            "read2".to_string(),
+            vec![
+                Arc::clone(&rg_paired),
+                Arc::clone(&rg_paired),
+                Arc::clone(&rg_single),
+            ],
+        );
+        read_names.insert("read3".to_string(), vec![Arc::clone(&rg_single)]);
+        read_names.insert(
+            "read4".to_string(),
+            vec![Arc::clone(&rg_paired), Arc::clone(&rg_paired)],
+        );
+        read_names.insert(
+            "read5".to_string(),
+            vec![
+                Arc::clone(&rg_paired),
+                Arc::clone(&rg_paired),
+                Arc::clone(&rg_single),
+            ],
+        );
+        let mut rg_rpts: HashMap<ReadGroupPtr, f64> = HashMap::new();
+        let overall_rpt = calculate_reads_per_template(read_names, &mut rg_rpts);
+        assert_eq!(rg_rpts.len(), 2);
+        assert_eq!(overall_rpt, 2.2);
+        assert_eq!(rg_rpts.get(&Arc::clone(&rg_paired)).unwrap(), &2.0);
+        assert_eq!(rg_rpts.get(&Arc::clone(&rg_single)).unwrap(), &1.0);
&1.0); + } + + #[test] + fn test_predict_with_rpt_complex() { + let mut ordering_flags: HashMap<Arc<String>, OrderingFlagsCounts> = HashMap::new(); + let rg_paired = Arc::new("rg_paired".to_string()); + let rg_single = Arc::new("rg_single".to_string()); + ordering_flags.insert( + Arc::clone(&rg_paired), + OrderingFlagsCounts { + unsegmented: 0, + first: 8, + last: 8, + both: 0, + neither: 0, + }, + ); + ordering_flags.insert( + Arc::clone(&rg_single), + OrderingFlagsCounts { + unsegmented: 2, + first: 0, + last: 0, + both: 0, + neither: 0, + }, + ); + let mut read_names: HashMap<String, Vec<Arc<String>>> = HashMap::new(); + read_names.insert( + "read1".to_string(), + vec![Arc::clone(&rg_paired), Arc::clone(&rg_paired)], + ); + read_names.insert( + "read2".to_string(), + vec![ + Arc::clone(&rg_paired), + Arc::clone(&rg_paired), + Arc::clone(&rg_single), + ], + ); + read_names.insert("read3".to_string(), vec![Arc::clone(&rg_single)]); + read_names.insert( + "read4".to_string(), + vec![Arc::clone(&rg_paired), Arc::clone(&rg_paired)], + ); + read_names.insert( + "read5".to_string(), + vec![ + Arc::clone(&rg_paired), + Arc::clone(&rg_paired), + Arc::clone(&rg_single), + ], + ); + let result = predict(ordering_flags, Some(read_names), 0.0, false); + assert!(!result.succeeded); + assert_eq!(result.endedness, None); + assert_eq!(result.unsegmented, 2); + assert_eq!(result.first, 8); + assert_eq!(result.last, 8); + assert_eq!(result.both, 0); + assert_eq!(result.neither, 0); + assert_eq!(result.rpt, Some(2.2)); + assert_eq!(result.read_groups.len(), 2); + // We can't know which read group will be first in the vector. + // But both should succeed. + assert!(result.read_groups[0].succeeded && result.read_groups[1].succeeded); + } +} diff --git a/src/derive/endedness/results.rs b/src/derive/endedness/results.rs new file mode 100644 index 0000000..43cddd4 --- /dev/null +++ b/src/derive/endedness/results.rs @@ -0,0 +1,119 @@ +//! Module holding the results structs for the `ngs derive endedness` subcommand. + +use serde::Serialize; + +use crate::derive::endedness::compute::OrderingFlagsCounts; + +/// Struct holding the per read group results for an `ngs derive endedness` +/// subcommand call. +#[derive(Debug, Serialize)] +pub struct ReadGroupDerivedEndednessResult { + /// Name of the read group. + pub read_group: String, + + /// Whether or not an endedness was determined for this read group. + pub succeeded: bool, + + /// The endedness of this read group, if derivable. + pub endedness: Option<String>, + + /// The number of reads without 0x1 set. + pub unsegmented: usize, + + /// The f+l- read count. + pub first: usize, + + /// The f-l+ read count. + pub last: usize, + + /// The f+l+ read count. + pub both: usize, + + /// The f-l- read count. + pub neither: usize, + + /// The reads per template (RPT). + /// Only available if `args.calc_rpt` is true. + pub rpt: Option<f64>, +} + +impl ReadGroupDerivedEndednessResult { + /// Creates a new [`ReadGroupDerivedEndednessResult`]. + pub fn new( + read_group: String, + succeeded: bool, + endedness: Option<String>, + counts: OrderingFlagsCounts, + rpt: Option<f64>, + ) -> Self { + ReadGroupDerivedEndednessResult { + read_group, + succeeded, + endedness, + unsegmented: counts.unsegmented, + first: counts.first, + last: counts.last, + both: counts.both, + neither: counts.neither, + rpt, + } + } +} + +/// Struct holding the final results for an `ngs derive endedness` subcommand +/// call. +#[derive(Debug, Serialize)] +pub struct DerivedEndednessResult { + /// Whether or not the `ngs derive endedness` subcommand succeeded.
+ pub succeeded: bool, + + /// The overall endedness, if derivable. + pub endedness: Option<String>, + + /// The number of reads without 0x1 set. + pub unsegmented: usize, + + /// The overall f+l- read count. + pub first: usize, + + /// The overall f-l+ read count. + pub last: usize, + + /// The overall f+l+ read count. + pub both: usize, + + /// The overall f-l- read count. + pub neither: usize, + + /// The overall reads per template (RPT). + /// Only available if `args.calc_rpt` is true. + pub rpt: Option<f64>, + + /// Vector of [`ReadGroupDerivedEndednessResult`]s. + /// One for each read group in the BAM, + /// and potentially one for any reads with an unknown read group. + pub read_groups: Vec<ReadGroupDerivedEndednessResult>, +} + +impl DerivedEndednessResult { + /// Creates a new [`DerivedEndednessResult`]. + pub fn new( + succeeded: bool, + endedness: Option<String>, + counts: OrderingFlagsCounts, + rpt: Option<f64>, + read_groups: Vec<ReadGroupDerivedEndednessResult>, + ) -> Self { + DerivedEndednessResult { + succeeded, + endedness, + unsegmented: counts.unsegmented, + first: counts.first, + last: counts.last, + both: counts.both, + neither: counts.neither, + rpt, + read_groups, + } + } +} diff --git a/src/derive/instrument/compute.rs b/src/derive/instrument/compute.rs index 48d0106..173e11d 100644 --- a/src/derive/instrument/compute.rs +++ b/src/derive/instrument/compute.rs @@ -1,14 +1,11 @@ //! Combines the flowcell and instrument checks into a single workflow. -use std::collections::HashMap; -use std::collections::HashSet; - use regex::Regex; use serde::Serialize; -use tracing::info; +use std::collections::{HashMap, HashSet}; -use super::flowcells; -use super::instruments; +use crate::derive::instrument::{flowcells, instruments}; +use crate::utils::read_groups::ReadGroupPtr; /// Generalized struct for holding instrument detection results. #[derive(Debug, Default, Serialize)] @@ -42,6 +39,81 @@ impl InstrumentDetectionResults { } } +/// A query for a look-up table and the resulting hits from that table. +#[derive(Debug, Serialize)] +pub struct QueryResult { + /// The query that was used to generate the result. + pub query: String, + + /// The possible instruments that could have generated the query. + pub result: HashSet<String>, +} + +/// Metrics related to how read records were processed. +#[derive(Debug, Default, Serialize)] +pub struct RecordMetrics { + /// The total number of records that were processed. + pub total_records: usize, + + /// The total number of records that couldn't be parsed + /// due to a missing or invalid read name. + pub bad_read_name: usize, + + /// The total number of records that contained a parseable + /// instrument name in their read name. + pub found_instrument_name: usize, + + /// The total number of records that contained a parseable + /// flowcell name in their read name. + pub found_flowcell_name: usize, +} + +/// Struct holding the per read group results for an `ngs derive instrument` +/// subcommand call. +#[derive(Debug, Serialize)] +pub struct ReadGroupDerivedInstrumentResult { + /// The read group that these results are associated with. + pub read_group: String, + + /// Whether or not the `ngs derive instrument` subcommand succeeded + /// for this read group. + pub succeeded: bool, + + /// The possible instruments detected for this read group, if derivable. + pub instruments: Option<HashSet<String>>, + + /// The level of confidence that the tool has concerning these results. + pub confidence: String, + + /// Status of the evidence that supports (or lack thereof) these predicted + /// instruments, if available.
+ pub evidence: Option<String>, + + /// A general comment field, if available. + pub comment: Option<String>, + + /// The results of the instrument name look-ups for this read group. + pub instrument_name_queries: Vec<QueryResult>, + + /// The results of the flowcell name look-ups for this read group. + pub flowcell_name_queries: Vec<QueryResult>, +} + +impl Default for ReadGroupDerivedInstrumentResult { + fn default() -> Self { + ReadGroupDerivedInstrumentResult { + read_group: String::new(), + succeeded: false, + instruments: None, + confidence: "unknown".to_string(), + evidence: None, + comment: None, + instrument_name_queries: Vec::new(), + flowcell_name_queries: Vec::new(), + } + } +} + /// Struct holding the final results for an `ngs derive instrument` subcommand /// call. #[derive(Debug, Serialize)] @@ -50,35 +122,38 @@ pub struct DerivedInstrumentResult { pub succeeded: bool, /// The possible instruments detected by `ngs derive instrument`, if - /// available. + /// derivable. pub instruments: Option<HashSet<String>>, /// The level of confidence that the tool has concerning these results. pub confidence: String, /// Status of the evidence that supports (or lack thereof) these predicted - /// instruments, if available. + /// instruments, if available. pub evidence: Option<String>, /// A general comment field, if available. pub comment: Option<String>, + + /// Vector of [`ReadGroupDerivedInstrumentResult`]s. + /// One for each read group in the BAM, + /// and potentially one for any reads with an unknown read group. + pub read_groups: Vec<ReadGroupDerivedInstrumentResult>, + + /// Metrics related to how read records were processed. + pub records: RecordMetrics, } -impl DerivedInstrumentResult { - /// Creates a new [`DerivedInstrumentResult`]. - pub fn new( - succeeded: bool, - instruments: Option<HashSet<String>>, - confidence: String, - evidence: Option<String>, - comment: Option<String>, - ) -> Self { +impl Default for DerivedInstrumentResult { + fn default() -> Self { DerivedInstrumentResult { - succeeded, - instruments, - confidence, - evidence, - comment, + succeeded: false, + instruments: None, + confidence: "unknown".to_string(), + evidence: None, + comment: None, + read_groups: Vec::new(), + records: RecordMetrics::default(), } } } @@ -103,25 +178,26 @@ impl DerivedInstrumentResult { pub fn possible_instruments_for_query( query: String, lookup_table: &HashMap<&'static str, HashSet<&'static str>>, -) -> HashSet<String> { - let mut result: HashSet<String> = HashSet::new(); +) -> QueryResult { + let mut result_set: HashSet<String> = HashSet::new(); for (pattern, machines) in lookup_table { let re = Regex::new(pattern).unwrap(); if re.is_match(query.as_str()) { let matching_machines: Vec<String> = machines.iter().map(|x| x.to_string()).collect(); - result.extend(matching_machines); + result_set.extend(matching_machines); } } - - info!(" [*] {}, Possible Instruments: {:?}", query, result); - result + QueryResult { + query, + result: result_set, + } } /// Given a HashSet of unique queries (usually an instrument ID or flowcell ID /// parsed from a read name) that were detected from a SAM/BAM/CRAM file, return /// a HashSet that contains all possible machines that could have generated that -/// list of queries. +/// list of queries and a vec recording the query look-ups that were made. /// /// This is done by iterating through the HashSet of machines that could have /// produced each name and taking the intersection.
It is possible, of course, @@ -141,64 +217,56 @@ pub fn possible_instruments_for_query( pub fn predict_instrument( queries: HashSet<String>, lookup_table: &HashMap<&'static str, HashSet<&'static str>>, -) -> InstrumentDetectionResults { +) -> (InstrumentDetectionResults, Vec<QueryResult>) { let mut result = InstrumentDetectionResults::default(); + let mut query_results = Vec::new(); for name in queries { let derived = possible_instruments_for_query(name, lookup_table); - result.update_instruments(&derived); + result.update_instruments(&derived.result); + query_results.push(derived); } - result + (result, query_results) } /// Combines evidence from the instrument id detection and flowcell id detection -/// to produce a final [`DerivedInstrumentResult`]. +/// to produce a [`ReadGroupDerivedInstrumentResult`]. pub fn resolve_instrument_prediction( iid_results: InstrumentDetectionResults, fcid_results: InstrumentDetectionResults, -) -> DerivedInstrumentResult { +) -> ReadGroupDerivedInstrumentResult { let possible_instruments_by_iid = iid_results.possible_instruments.unwrap_or_default(); let possible_instruments_by_fcid = fcid_results.possible_instruments.unwrap_or_default(); + let mut result = ReadGroupDerivedInstrumentResult::default(); + // (1) If the set of possible instruments as determined by the instrument id // is empty _and_ we have seen at least one machine, then the only possible // scenario is there are conflicting instrument ids. if possible_instruments_by_iid.is_empty() && iid_results.detected_at_least_one_machine { - return DerivedInstrumentResult::new( - false, - None, - "unknown".to_string(), - Some("instrument id".to_string()), - Some( - "multiple instruments were detected in this file via the instrument id".to_string(), - ), + result.evidence = Some("instrument id".to_string()); + result.comment = Some( + "multiple instruments were detected in this file via the instrument id".to_string(), ); + return result; } // (2) If the set of possible instruments as determined by the flowcell id // is empty _and_ we have seen at least one machine, then the only possible // scenario is there are conflicting flowcell ids. if possible_instruments_by_fcid.is_empty() && fcid_results.detected_at_least_one_machine { - return DerivedInstrumentResult::new( - false, - None, - "unknown".to_string(), - Some("flowcell id".to_string()), - Some("multiple instruments were detected in this file via the flowcell id".to_string()), - ); + result.evidence = Some("flowcell id".to_string()); + result.comment = + Some("multiple instruments were detected in this file via the flowcell id".to_string()); + return result; } // (3) if neither result turns up anything, then we can simply say that the // machine was not able to be detected.
if possible_instruments_by_iid.is_empty() && possible_instruments_by_fcid.is_empty() { - return DerivedInstrumentResult::new( - false, - None, - "unknown".to_string(), - None, - Some("no matching instruments were found".to_string()), - ); + result.comment = Some("no matching instruments were found".to_string()); + return result; } // (4) If both aren't empty and iid_results _is_ empty, then the fcid @@ -211,13 +279,11 @@ pub fn resolve_instrument_prediction( _ => "low", }; - return DerivedInstrumentResult::new( - true, - Some(instruments), - confidence.to_string(), - Some("flowcell id".to_string()), - None, - ); + result.succeeded = true; + result.instruments = Some(instruments); + result.confidence = confidence.to_string(); + result.evidence = Some("flowcell id".to_string()); + return result; } // (5) Same as the block above, except now we are evaluating the opposite @@ -229,13 +295,11 @@ _ => "low", }; - return DerivedInstrumentResult::new( - true, - Some(instruments), - confidence.to_string(), - Some("instrument id".to_string()), - None, - ); + result.succeeded = true; + result.instruments = Some(instruments); + result.confidence = confidence.to_string(); + result.evidence = Some("instrument id".to_string()); + return result; } let overlapping_instruments: HashSet<String> = possible_instruments_by_fcid @@ -244,42 +308,66 @@ .intersection(&possible_instruments_by_iid) .cloned() .collect(); if overlapping_instruments.is_empty() { - return DerivedInstrumentResult::new( - false, - None, - "high".to_string(), - Some("instrument and flowcell id".to_string()), - Some( - "Case needs triaging, results from instrument id and \ flowcell id are mutually exclusive." - .to_string(), - ), + result.confidence = "high".to_string(); + result.evidence = Some("instrument and flowcell id".to_string()); + result.comment = Some( + "Case needs triaging, results from instrument id and \ flowcell id are mutually exclusive." + .to_string(), ); + return result; } - DerivedInstrumentResult::new( - true, - Some(overlapping_instruments), - "high".to_string(), - Some("instrument and flowcell id".to_string()), - None, - ) + result.succeeded = true; + result.instruments = Some(overlapping_instruments); + result.confidence = "high".to_string(); + result.evidence = Some("instrument and flowcell id".to_string()); + result } /// Main method to evaluate the detected instrument names and flowcell names and /// return a result for the derived instruments. This may fail, and the /// resulting [`DerivedInstrumentResult`] should be evaluated accordingly.
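// A worked sketch of the flow below (hypothetical values, mirroring the tests at
// the bottom of this file): given
//   instrument_names = { "RG1" => {"A00000"} }
//   flowcell_names   = { "RG1" => {"H00000RXX"} }
// each read group is resolved on its own into a ReadGroupDerivedInstrumentResult;
// all names are then pooled across read groups and resolved once more for the
// overall call. Here both look-ups match NovaSeq patterns, so the overall
// prediction succeeds with evidence "instrument and flowcell id".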
pub fn predict( - instrument_names: HashSet<String>, - flowcell_names: HashSet<String>, + instrument_names: HashMap<ReadGroupPtr, HashSet<String>>, + flowcell_names: HashMap<ReadGroupPtr, HashSet<String>>, ) -> DerivedInstrumentResult { let instruments = instruments::build_instrument_lookup_table(); let flowcells = flowcells::build_flowcell_lookup_table(); - let iid_results = predict_instrument(instrument_names, &instruments); - let fcid_results = predict_instrument(flowcell_names, &flowcells); + let mut rg_results = Vec::new(); + let mut all_instrument_names = HashSet::new(); + let mut all_flowcell_names = HashSet::new(); + + for rg in instrument_names.keys() { + all_instrument_names.extend(instrument_names[rg].iter().cloned()); + all_flowcell_names.extend(flowcell_names[rg].iter().cloned()); + + let (rg_iid_results, rg_instrument_name_queries) = + predict_instrument(instrument_names[rg].clone(), &instruments); + let (rg_fcid_results, rg_flowcell_name_queries) = + predict_instrument(flowcell_names[rg].clone(), &flowcells); + + let mut rg_result = resolve_instrument_prediction(rg_iid_results, rg_fcid_results); + rg_result.read_group = rg.to_string(); + rg_result.instrument_name_queries = rg_instrument_name_queries; + rg_result.flowcell_name_queries = rg_flowcell_name_queries; + rg_results.push(rg_result); + } - resolve_instrument_prediction(iid_results, fcid_results) + let (iid_results, _) = predict_instrument(all_instrument_names, &instruments); + let (fcid_results, _) = predict_instrument(all_flowcell_names, &flowcells); + + let overall_prediction = resolve_instrument_prediction(iid_results, fcid_results); + DerivedInstrumentResult { + succeeded: overall_prediction.succeeded, + instruments: overall_prediction.instruments, + confidence: overall_prediction.confidence, + evidence: overall_prediction.evidence, + comment: overall_prediction.comment, + read_groups: rg_results, + ..DerivedInstrumentResult::default() + } } #[cfg(test)] @@ -290,36 +378,42 @@ mod tests { fn test_derive_instrument_from_invalid_instrument_name() { let instruments = instruments::build_instrument_lookup_table(); let result = possible_instruments_for_query(String::from("NoMatchingName"), &instruments); - assert!(result.is_empty()); + assert!(result.result.is_empty()); } #[test] fn test_derive_instrument_from_valid_instrument_name() { let instruments = instruments::build_instrument_lookup_table(); let result = possible_instruments_for_query(String::from("A00000"), &instruments); - assert_eq!(result.len(), 1); - assert!(result.contains("NovaSeq")); + assert_eq!(result.result.len(), 1); + assert!(result.result.contains("NovaSeq")); } #[test] fn test_derive_instrument_from_invalid_flowcell_name() { let flowcells = flowcells::build_flowcell_lookup_table(); let result = possible_instruments_for_query(String::from("NoMatchingName"), &flowcells); - assert!(result.is_empty()); + assert!(result.result.is_empty()); } #[test] fn test_derive_instrument_from_valid_flowcell_name() { let flowcells = flowcells::build_flowcell_lookup_table(); let result = possible_instruments_for_query(String::from("H00000RXX"), &flowcells); - assert_eq!(result.len(), 1); - assert!(result.contains("NovaSeq")); + assert_eq!(result.result.len(), 1); + assert!(result.result.contains("NovaSeq")); } #[test] fn test_derive_instrument_novaseq_successfully() { - let detected_iids = HashSet::from(["A00000".to_string()]); - let detected_fcids = HashSet::from(["H00000RXX".to_string()]); + let detected_iids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["A00000".to_string()]), + )]); + let detected_fcids =
HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["H00000RXX".to_string()]), + )]); let result = predict(detected_iids, detected_fcids); assert!(result.succeeded); @@ -337,8 +431,23 @@ mod tests { #[test] fn test_derive_instrument_conflicting_instrument_ids() { - let detected_iids = HashSet::from(["A00000".to_string(), "D00000".to_string()]); - let detected_fcids = HashSet::from(["H00000RXX".to_string()]); + let detected_iids = HashMap::from([ + ( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["A00000".to_string()]), + ), + ( + ReadGroupPtr::from("RG2".to_string()), + HashSet::from(["D00000".to_string()]), + ), + ]); + let detected_fcids = HashMap::from([ + ( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["H00000RXX".to_string()]), + ), + (ReadGroupPtr::from("RG2".to_string()), HashSet::new()), + ]); let result = predict(detected_iids, detected_fcids); assert!(!result.succeeded); @@ -351,12 +460,30 @@ mod tests { "multiple instruments were detected in this file via the instrument id".to_string() ) ); + // We can't know which read group will be first in the vector. + // But both should succeed. + assert!(result.read_groups[0].succeeded && result.read_groups[1].succeeded); } #[test] fn test_derive_instrument_conflicting_flowcell_ids() { - let detected_iids = HashSet::from(["A00000".to_string()]); - let detected_fcids = HashSet::from(["H00000RXX".to_string(), "B0000".to_string()]); + let detected_iids = HashMap::from([ + ( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["A00000".to_string()]), + ), + (ReadGroupPtr::from("RG2".to_string()), HashSet::new()), + ]); + let detected_fcids = HashMap::from([ + ( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["H00000RXX".to_string()]), + ), + ( + ReadGroupPtr::from("RG2".to_string()), + HashSet::from(["B0000".to_string()]), + ), + ]); let result = predict(detected_iids, detected_fcids); assert!(!result.succeeded); @@ -367,12 +494,19 @@ mod tests { result.comment, Some("multiple instruments were detected in this file via the flowcell id".to_string()) ); + // We can't know which read group will be first in the vector. + // But both should succeed. 
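+ // (The ordering is unspecified because per-read-group results are collected by iterating a std HashMap, whose iteration order varies from run to run.)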
+ assert!(result.read_groups[0].succeeded && result.read_groups[1].succeeded); } #[test] fn test_derive_instrument_medium_instrument_evidence() { - let detected_iids = HashSet::from(["A00000".to_string()]); - let detected_fcids = HashSet::new(); + let detected_iids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["A00000".to_string()]), + )]); + let detected_fcids = + HashMap::from([(ReadGroupPtr::from("RG1".to_string()), HashSet::new())]); let result = predict(detected_iids, detected_fcids); assert!(result.succeeded); @@ -387,8 +521,12 @@ mod tests { #[test] fn test_derive_instrument_low_instrument_evidence() { - let detected_iids = HashSet::from(["K00000".to_string()]); - let detected_fcids = HashSet::new(); + let detected_iids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["K00000".to_string()]), + )]); + let detected_fcids = + HashMap::from([(ReadGroupPtr::from("RG1".to_string()), HashSet::new())]); let result = predict(detected_iids, detected_fcids); assert!(result.succeeded); @@ -406,8 +544,12 @@ mod tests { #[test] fn test_derive_instrument_medium_flowcell_evidence() { - let detected_iids = HashSet::new(); - let detected_fcids = HashSet::from(["H00000RXX".to_string()]); + let detected_iids = + HashMap::from([(ReadGroupPtr::from("RG1".to_string()), HashSet::new())]); + let detected_fcids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["H00000RXX".to_string()]), + )]); let result = predict(detected_iids, detected_fcids); assert!(result.succeeded); @@ -422,8 +564,12 @@ mod tests { #[test] fn test_derive_instrument_low_flowcell_evidence() { - let detected_iids = HashSet::new(); - let detected_fcids = HashSet::from(["H0000ADXX".to_string()]); + let detected_iids = + HashMap::from([(ReadGroupPtr::from("RG1".to_string()), HashSet::new())]); + let detected_fcids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["H0000ADXX".to_string()]), + )]); let result = predict(detected_iids, detected_fcids); assert!(result.succeeded); @@ -442,8 +588,14 @@ mod tests { #[test] fn test_derive_instrument_conflicting_flowcell_and_instrument_evidence() { - let detected_iids = HashSet::from(["K00000".to_string()]); - let detected_fcids = HashSet::from(["H00000RXX".to_string()]); + let detected_iids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["K00000".to_string()]), + )]); + let detected_fcids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["H00000RXX".to_string()]), + )]); let result = predict(detected_iids, detected_fcids); assert!(!result.succeeded); @@ -458,8 +610,14 @@ mod tests { #[test] fn test_derive_instrument_no_matches() { - let detected_iids = HashSet::from(["QQQQQ".to_string()]); - let detected_fcids = HashSet::from(["ZZZZZZ".to_string()]); + let detected_iids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["QQQQQ".to_string()]), + )]); + let detected_fcids = HashMap::from([( + ReadGroupPtr::from("RG1".to_string()), + HashSet::from(["ZZZZZZ".to_string()]), + )]); let result = predict(detected_iids, detected_fcids); assert!(!result.succeeded); diff --git a/src/derive/instrument/flowcells.rs b/src/derive/instrument/flowcells.rs index 4bb2389..b339cfa 100644 --- a/src/derive/instrument/flowcells.rs +++ b/src/derive/instrument/flowcells.rs @@ -1,7 +1,6 @@ //! Knowledge about which flowcells map to which machine types. 
-use std::collections::HashMap; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; /// Encapsulates the knowledge we currently have on which flowcell patterns map /// to which machine types as a [`HashMap`]. diff --git a/src/derive/instrument/instruments.rs b/src/derive/instrument/instruments.rs index e10b114..a1df971 100644 --- a/src/derive/instrument/instruments.rs +++ b/src/derive/instrument/instruments.rs @@ -1,7 +1,6 @@ //! Knowledge about which instrument ids map to which machine types. -use std::collections::HashMap; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; /// Encapsulates the knowledge we currently have on which instrument name patterns map /// to which machine types as a [`HashMap`]. diff --git a/src/derive/junction_annotation.rs b/src/derive/junction_annotation.rs new file mode 100644 index 0000000..77e2617 --- /dev/null +++ b/src/derive/junction_annotation.rs @@ -0,0 +1,4 @@ +//! Supporting functionality for the `ngs derive junction-annotation` subcommand. + +pub mod compute; +pub mod results; diff --git a/src/derive/junction_annotation/compute.rs b/src/derive/junction_annotation/compute.rs new file mode 100644 index 0000000..e50ce99 --- /dev/null +++ b/src/derive/junction_annotation/compute.rs @@ -0,0 +1,1074 @@ +//! Module holding the logic for annotating junctions. + +use anyhow::{bail, Ok}; +use noodles::core::Position; +use noodles::sam::alignment::Record; +use noodles::sam::record::cigar::op::Kind; +use noodles::sam::record::MappingQuality; +use noodles::sam::Header; +use std::collections::{HashMap, HashSet}; + +use crate::derive::junction_annotation::results; +use crate::utils::alignment::filter_by_mapq; + +/// Struct to hold starts and ends of exons. +pub struct ExonSets<'a> { + /// Starts of exons, grouped by contig. + pub starts: HashMap<&'a str, HashSet<Position>>, + + /// Ends of exons, grouped by contig. + pub ends: HashMap<&'a str, HashSet<Position>>, +} + +/// Parameters defining how to annotate found junctions. +pub struct JunctionAnnotationParameters { + /// Minimum intron length to consider. + pub min_intron_length: usize, + + /// Minimum number of reads supporting a junction to be considered. + pub min_read_support: usize, + + /// Minimum mapping quality for a record to be considered. + /// `None` means no filtering by MAPQ. This also allows + /// for records _without_ a MAPQ to be counted. + pub min_mapq: Option<MappingQuality>, + + /// Do not count supplementary alignments. + pub no_supplementary: bool, + + /// Do count secondary alignments. + pub count_secondary: bool, + + /// Do count duplicates. + pub count_duplicates: bool, +} + +/// Function for incrementing a junction counter by one. +fn increment_junction_counter( + junction_counter: &mut results::JunctionCounter, + junction: results::Junction, +) { + junction_counter + .entry(junction) + .and_modify(|e| *e += 1) + .or_insert(1); +} + +/// Function for incrementing a junction map by one. +fn increment_junction_map( + junction_map: &mut results::JunctionsMap, + ref_name: &str, + junction: results::Junction, +) { + increment_junction_counter( + junction_map.entry(ref_name.to_string()).or_default(), + junction, + ); +} + +/// Function to filter out records based on their flags.
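+/// Returns `true` when the record should be skipped: unmapped (0x4) records are always skipped; supplementary (0x800) records are skipped when `no_supplementary` is set; secondary (0x100) and duplicate (0x400) records are skipped unless counting them was requested.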
+fn filter_by_flags(record: &Record, params: &JunctionAnnotationParameters) -> bool { + let flags = record.flags(); + if flags.is_unmapped() + || (params.no_supplementary && flags.is_supplementary()) + || (!params.count_secondary && flags.is_secondary()) + || (!params.count_duplicates && flags.is_duplicate()) + { + return true; + } + false +} + +/// Function to filter out records that don't have introns. +fn filter_by_cigar(record: &Record) -> bool { + !record + .cigar() + .iter() + .any(|op| matches!(op.kind(), Kind::Skip)) +} + +/// Main function to annotate junctions one record at a time. +pub fn process( + record: &Record, + exons: &ExonSets<'_>, + header: &Header, + params: &JunctionAnnotationParameters, + results: &mut results::JunctionAnnotationResults, +) -> anyhow::Result<()> { + // (1) Parse the read name. + let read_name = match record.read_name() { + Some(name) => name, + _ => bail!("Could not parse read name"), + }; + + // (2) Filter by record flags. + if filter_by_flags(record, params) { + results.records.filtered_by_flags += 1; + return Ok(()); + } + + // (3) Filter by CIGAR. + // We only care about reads with introns, so if there are no introns + // we can skip this read. + if filter_by_cigar(record) { + results.records.not_spliced += 1; + return Ok(()); + } + + // (4) Filter by MAPQ + if filter_by_mapq(record, params.min_mapq) { + results.records.bad_mapq += 1; + return Ok(()); + } + + // (5) Parse the reference sequence from the record. + let (seq_name, _) = match record.reference_sequence(header) { + Some(seq_map_result) => seq_map_result?, + _ => { + bail!( + "Could not parse reference sequence id for read: {}", + read_name + ) + } + }; + let seq_name = seq_name.as_str(); + + // (6) Check if there will be annotations for this reference sequence. + let mut ref_is_annotated = true; + if !exons.starts.contains_key(seq_name) || !exons.ends.contains_key(seq_name) { + ref_is_annotated = false; + } + + // (7) Calculate the start position of this read. This will + // be used to find the position of any introns. + let start = match record.alignment_start() { + Some(s) => s, + _ => bail!("Could not parse record's start position."), + }; + + // (8) Find introns + let mut cur_pos = start; + for op in record.cigar().iter() { + match op.kind() { + // This is an intron. + Kind::Skip => { + // Check that `op.len() >= params.min_intron_length` later, + // once all reads supporting short junctions have been collected + // for better metric reporting. + + let intron_start = cur_pos; + // Update cur_pos to the end of the intron. + cur_pos = cur_pos.checked_add(op.len()).unwrap(); + let intron_end = cur_pos; + let junction: results::Junction = (intron_start, intron_end); + + // If the reference sequence is not annotated, we can skip + // the lookup of exon positions, and directly insert the + // intron into the unannotated_reference HashMap. + if !ref_is_annotated { + increment_junction_map( + &mut results.junction_annotations.unannotated_reference, + seq_name, + junction, + ); + continue; + } + + // The following unwraps are safe because we checked that the reference + // sequence is annotated above. 
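+ // Classification below is a two-bit decision: the intron start is looked up against annotated exon *ends* and the intron end against annotated exon *starts*. + // (start known, end known) -> known junction + // exactly one coordinate known -> partial novel junction + // neither coordinate known -> complete novel junction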
+ let exon_starts = exons.starts.get(seq_name).unwrap(); + let exon_ends = exons.ends.get(seq_name).unwrap(); + + let mut intron_start_known = false; + let mut intron_end_known = false; + if exon_ends.contains(&intron_start) { + intron_start_known = true; + } + if exon_starts.contains(&intron_end) { + intron_end_known = true; + } + + let junction_map = match (intron_start_known, intron_end_known) { + (true, true) => { + // We found both ends of the intron. + // This is a Known Junction. + &mut results.junction_annotations.known + } + (true, false) | (false, true) => { + // We found one end of the intron, + // but not the other. + // This is a Partial Novel Junction. + &mut results.junction_annotations.partial_novel + } + (false, false) => { + // We found neither end of the intron. + // This is a Complete Novel Junction. + &mut results.junction_annotations.complete_novel + } + }; + increment_junction_map(junction_map, seq_name, junction) + } + // Operations that increment the reference position (besides Skip, which is handled above). + Kind::Match | Kind::Deletion | Kind::SequenceMatch | Kind::SequenceMismatch => { + cur_pos = cur_pos.checked_add(op.len()).unwrap(); + } + // Operations that do not increment the reference position. + _ => {} + } + } + + results.records.processed += 1; + Ok(()) +} + +/// Function to filter out junctions that are too short or don't have enough read support. +fn filter_junction_map( + junction_map: &mut results::JunctionsMap, + min_intron_length: usize, + min_read_support: usize, + metrics: &mut results::SummaryResults, +) { + junction_map.retain(|_, v| { + v.retain(|(start, end), count| { + let mut keep = true; + if end.get() - start.get() < min_intron_length { + metrics.intron_too_short += 1; + keep = false; + } + if *count < min_read_support { + metrics.junctions_with_not_enough_read_support += 1; + keep = false; + } + if !keep { + metrics.total_rejected_junctions += 1; + } + keep + }); + !v.is_empty() + }); +} + +/// Function to tally up the junctions and their read support. +fn tally_junctions_and_support(junction_map: &results::JunctionsMap) -> (usize, usize) { + let junctions = junction_map.values().map(|v| v.len()).sum(); + let support = junction_map + .values() + .map(|v| v.values().sum::<usize>()) + .sum(); + (junctions, support) +} + +/// Main function to summarize the results of the junction-annotation subcommand. +pub fn summarize( + results: &mut results::JunctionAnnotationResults, + params: &JunctionAnnotationParameters, +) { + // Filter out junctions that are too short or don't have enough read support. + filter_junction_map( + &mut results.junction_annotations.known, + params.min_intron_length, + params.min_read_support, + &mut results.summary, + ); + filter_junction_map( + &mut results.junction_annotations.partial_novel, + params.min_intron_length, + params.min_read_support, + &mut results.summary, + ); + filter_junction_map( + &mut results.junction_annotations.complete_novel, + params.min_intron_length, + params.min_read_support, + &mut results.summary, + ); + filter_junction_map( + &mut results.junction_annotations.unannotated_reference, + params.min_intron_length, + params.min_read_support, + &mut results.summary, + ); + + // Tally up observed junctions and spliced reads.
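+ // Each call below fills a (junction count, read support) pair of summary fields at once via destructuring assignment (stable since Rust 1.59).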
+ ( + results.summary.known_junctions, + results.summary.known_junctions_read_support, + ) = tally_junctions_and_support(&results.junction_annotations.known); + ( + results.summary.partial_novel_junctions, + results.summary.partial_novel_junctions_read_support, + ) = tally_junctions_and_support(&results.junction_annotations.partial_novel); + ( + results.summary.complete_novel_junctions, + results.summary.complete_novel_junctions_read_support, + ) = tally_junctions_and_support(&results.junction_annotations.complete_novel); + ( + results.summary.unannotated_reference_junctions, + results.summary.unannotated_reference_junctions_read_support, + ) = tally_junctions_and_support(&results.junction_annotations.unannotated_reference); + + // Tally up total junctions. + results.summary.total_junctions = results.summary.known_junctions + + results.summary.partial_novel_junctions + + results.summary.complete_novel_junctions + + results.summary.unannotated_reference_junctions; + // Tally up total read support. + results.summary.total_junctions_read_support = results.summary.known_junctions_read_support + + results.summary.partial_novel_junctions_read_support + + results.summary.complete_novel_junctions_read_support + + results.summary.unannotated_reference_junctions_read_support; + + // Calculate percentages. + let total_junctions = results.summary.total_junctions as f64 + - results.summary.unannotated_reference_junctions as f64; // exclude unannotated junctions from percentages + results.summary.known_junctions_percent = + results.summary.known_junctions as f64 / total_junctions * 100.0; + results.summary.partial_novel_junctions_percent = + results.summary.partial_novel_junctions as f64 / total_junctions * 100.0; + results.summary.complete_novel_junctions_percent = + results.summary.complete_novel_junctions as f64 / total_junctions * 100.0; + + // Calculate average read support. 
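+ // All divisions below are f64, so a category with zero junctions yields NaN rather than panicking; NaN here effectively means "no junctions observed" for that category.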
+ // Total + results.summary.average_junction_read_support = results.summary.total_junctions_read_support + as f64 + / results.summary.total_junctions as f64; + // Known + results.summary.average_known_junction_read_support = + results.summary.known_junctions_read_support as f64 + / results.summary.known_junctions as f64; + // Partial Novel + results.summary.average_partial_novel_junction_read_support = + results.summary.partial_novel_junctions_read_support as f64 + / results.summary.partial_novel_junctions as f64; + // Complete Novel + results.summary.average_complete_novel_junction_read_support = + results.summary.complete_novel_junctions_read_support as f64 + / results.summary.complete_novel_junctions as f64; +} + +#[cfg(test)] +mod tests { + use super::*; + use noodles::core::Position; + use noodles::sam::header::record::value::map; + use noodles::sam::header::record::value::map::header::Version; + use noodles::sam::header::record::value::map::{Map, ReferenceSequence}; + use noodles::sam::record::MappingQuality; + use noodles::sam::record::ReadName; + use std::num::NonZeroUsize; + + fn create_test_exons() -> ExonSets<'static> { + let exon_starts: HashMap<&str, HashSet<Position>> = HashMap::from([( + "sq1", + HashSet::from([ + Position::new(1).unwrap(), + Position::new(11).unwrap(), + Position::new(21).unwrap(), + Position::new(31).unwrap(), + Position::new(41).unwrap(), + Position::new(51).unwrap(), + Position::new(61).unwrap(), + Position::new(71).unwrap(), + ]), + )]); + let exon_ends: HashMap<&str, HashSet<Position>> = exon_starts + .iter() + .map(|(k, v)| (*k, v.iter().map(|e| e.checked_add(10).unwrap()).collect())) + .collect::<HashMap<&str, HashSet<Position>>>(); + let exons: ExonSets<'_> = ExonSets { + starts: exon_starts, + ends: exon_ends, + }; + exons + } + + fn create_test_header() -> Header { + Header::builder() + .set_header(Map::<map::Header>::new(Version::new(1, 6))) + .add_reference_sequence( + "sq1".parse().unwrap(), + Map::<ReferenceSequence>::new(NonZeroUsize::try_from(800).unwrap()), + ) + .build() + } + + #[test] + fn test_filter_by_flags() { + // Setup + let mut record = Record::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + + // Test that records are filtered out correctly + record.flags_mut().set(0x4.into(), true); + assert!(filter_by_flags(&record, &params)); + record.flags_mut().set(0x4.into(), false); + record.flags_mut().set(0x800.into(), true); + assert!(!filter_by_flags(&record, &params)); + record.flags_mut().set(0x800.into(), false); + record.flags_mut().set(0x100.into(), true); + assert!(filter_by_flags(&record, &params)); + record.flags_mut().set(0x100.into(), false); + record.flags_mut().set(0x400.into(), true); + assert!(filter_by_flags(&record, &params)); + record.flags_mut().set(0x400.into(), false); + assert!(!filter_by_flags(&record, &params)); + } + + #[test] + fn test_filter_by_cigar() { + // Setup + let mut record = Record::default(); + + // Test that records are filtered out correctly + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + assert!(!filter_by_cigar(&record)); + *record.cigar_mut() = "10M".parse().unwrap(); + assert!(filter_by_cigar(&record)); + } + + #[test] + fn test_filter_junction_map() { + // Setup + let mut junction_map = results::JunctionsMap::default(); + junction_map.insert( + "sq1".to_string(), + HashMap::from([ + ((Position::new(1).unwrap(), Position::new(10).unwrap()), 3), + ((Position::new(1).unwrap(),
Position::new(11).unwrap()), 1), + ((Position::new(1).unwrap(), Position::new(5).unwrap()), 1), + ]), + ); + junction_map.insert( + "sq2".to_string(), + HashMap::from([((Position::new(1).unwrap(), Position::new(11).unwrap()), 2)]), + ); + let min_intron_length = 10; + let min_read_support = 2; + let mut metrics = results::SummaryResults::default(); + + // Test that junctions are filtered out correctly + filter_junction_map( + &mut junction_map, + min_intron_length, + min_read_support, + &mut metrics, + ); + assert_eq!(junction_map.len(), 1); + assert_eq!(junction_map.get("sq1"), None); + assert_eq!(junction_map.get("sq2").unwrap().len(), 1); + assert_eq!(metrics.intron_too_short, 2); + assert_eq!(metrics.junctions_with_not_enough_read_support, 2); + assert_eq!(metrics.total_rejected_junctions, 3); + } + + #[test] + fn test_tally_junctions_and_support() { + // Setup + let mut junction_map = results::JunctionsMap::default(); + junction_map.insert( + "sq1".to_string(), + HashMap::from([ + ((Position::new(1).unwrap(), Position::new(11).unwrap()), 1), + ((Position::new(1).unwrap(), Position::new(5).unwrap()), 1), + ]), + ); + junction_map.insert( + "sq2".to_string(), + HashMap::from([((Position::new(1).unwrap(), Position::new(11).unwrap()), 2)]), + ); + + // Test that junctions are tallied correctly + let (juncs, support) = tally_junctions_and_support(&junction_map); + assert_eq!(juncs, 3); + assert_eq!(support, 4); + } + + #[test] + fn test_process_known_junction() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test known junction + let mut record = Record::default(); + let r1_name: ReadName = "known1".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + assert_eq!(results.junction_annotations.known.len(), 1); + assert_eq!( + results.junction_annotations.known.get("sq1").unwrap().len(), + 1 + ); + assert_eq!( + results + .junction_annotations + .known + .get("sq1") + .unwrap() + .get(&(Position::new(11).unwrap(), Position::new(21).unwrap())) + .unwrap(), + &1 + ); + } + + #[test] + fn test_process_partial_novel_junction() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test partial novel junction + let mut record = Record::default(); + let r1_name: ReadName = "partial1".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() =
Position::new(1); + *record.cigar_mut() = "10M12N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + assert_eq!(results.junction_annotations.partial_novel.len(), 1); + assert_eq!( + results + .junction_annotations + .partial_novel + .get("sq1") + .unwrap() + .len(), + 1 + ); + assert_eq!( + results + .junction_annotations + .partial_novel + .get("sq1") + .unwrap() + .get(&(Position::new(11).unwrap(), Position::new(23).unwrap())) + .unwrap(), + &1 + ); + } + + #[test] + fn test_process_complete_novel_junction() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test complete novel junction + let mut record = Record::default(); + let r1_name: ReadName = "complete1".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "85M14N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + assert_eq!(results.junction_annotations.complete_novel.len(), 1); + assert_eq!( + results + .junction_annotations + .complete_novel + .get("sq1") + .unwrap() + .len(), + 1 + ); + assert_eq!( + results + .junction_annotations + .complete_novel + .get("sq1") + .unwrap() + .get(&(Position::new(86).unwrap(), Position::new(100).unwrap())) + .unwrap(), + &1 + ); + } + + #[test] + fn test_process_ignores_unmapped() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test that unmapped gets ignored + let mut record = Record::default(); + let r1_name: ReadName = "unmapped".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), true); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 0); + assert_eq!(results.records.filtered_by_flags, 1); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + } + + #[test] + fn test_process_supplementary_toggle() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let mut params =
JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: true, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test that supplementary gets ignored + let mut record = Record::default(); + let r1_name: ReadName = "supplementary1".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + record.flags_mut().set(0x800.into(), true); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 0); + assert_eq!(results.records.filtered_by_flags, 1); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + + // Test that supplementary gets processed + params.no_supplementary = false; + + let mut record = Record::default(); + let r2_name = "supplementary2".parse().unwrap(); + *record.read_name_mut() = Some(r2_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + record.flags_mut().set(0x800.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 1); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + } + + #[test] + fn test_process_secondary_toggle() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let mut params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: true, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test that secondary gets processed + let mut record = Record::default(); + let r1_name: ReadName = "secondary1".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + record.flags_mut().set(0x100.into(), true); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + + // Test that secondary gets ignored + params.count_secondary = false; + + let mut record = Record::default(); + let r2_name = "secondary2".parse().unwrap(); + *record.read_name_mut() = Some(r2_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + record.flags_mut().set(0x100.into(), true); + process(&record, &exons, &header, &params, &mut results).unwrap(); +
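+ // `processed` stays at 1, carried over from the first record above; with `count_secondary` now false, this secondary record is filtered by its flags instead.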
assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 1); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + } + + #[test] + fn test_process_mapq_toggle() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let mut params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test that mapq gets processed + let mut record = Record::default(); + let r1_name: ReadName = "mapq1".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + + // Test that mapq gets ignored + params.min_mapq = Some(MappingQuality::new(61).unwrap()); + + let mut record = Record::default(); + let r2_name = "mapq2".parse().unwrap(); + *record.read_name_mut() = Some(r2_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 1); + } + + #[test] + fn test_process_intron_too_short() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test that intron too short gets processed + let mut record = Record::default(); + let r1_name: ReadName = "short1".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M5N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); // processed at first, gets filtered later + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + } + + #[test] + fn test_process_multiple_junctions() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates:
false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test that multiple junctions are processed + let mut record = Record::default(); + let r1_name: ReadName = "long_read".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M10N10M10N10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + assert_eq!(results.junction_annotations.known.len(), 1); + assert_eq!( + results.junction_annotations.known.get("sq1").unwrap().len(), + 3 + ); + assert_eq!( + results + .junction_annotations + .known + .get("sq1") + .unwrap() + .get(&(Position::new(11).unwrap(), Position::new(21).unwrap())) + .unwrap(), + &1 + ); + assert_eq!( + results + .junction_annotations + .known + .get("sq1") + .unwrap() + .get(&(Position::new(31).unwrap(), Position::new(41).unwrap())) + .unwrap(), + &1 + ); + assert_eq!( + results + .junction_annotations + .known + .get("sq1") + .unwrap() + .get(&(Position::new(51).unwrap(), Position::new(61).unwrap())) + .unwrap(), + &1 + ); + } + + #[test] + fn test_process_unspliced_read() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let exons = create_test_exons(); + let header = create_test_header(); + + // Test that unspliced gets ignored + let mut record = Record::default(); + let r1_name: ReadName = "unspliced".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M".parse().unwrap(); + *record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 0); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 1); + assert_eq!(results.records.bad_mapq, 0); + } + + #[test] + fn test_process_unannotated_reference() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + let rand_header = Header::builder() + .set_header(Map::<map::Header>::new(Version::new(1, 6))) + .add_reference_sequence( + "sq1_random".parse().unwrap(), + Map::<ReferenceSequence>::new(NonZeroUsize::try_from(800).unwrap()), + ) + .build(); + let exons = create_test_exons(); + + // Test that unannotated reference gets processed + let mut record = Record::default(); + let r1_name: ReadName = "unannotated".parse().unwrap(); + *record.read_name_mut() = Some(r1_name); + *record.reference_sequence_id_mut() = Some(0); + *record.alignment_start_mut() = Position::new(1); + *record.cigar_mut() = "10M10N10M".parse().unwrap(); +
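+ // With an alignment start of 1, 10M10N10M places the skipped region at positions (11, 21); because "sq1_random" has no exon annotations, the junction should land in `unannotated_reference`.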
*record.mapping_quality_mut() = MappingQuality::new(60); + record.flags_mut().set(0x4.into(), false); + process(&record, &exons, &rand_header, &params, &mut results).unwrap(); + assert_eq!(results.records.processed, 1); + assert_eq!(results.records.filtered_by_flags, 0); + assert_eq!(results.records.not_spliced, 0); + assert_eq!(results.records.bad_mapq, 0); + assert_eq!(results.junction_annotations.unannotated_reference.len(), 1); + assert_eq!( + results + .junction_annotations + .unannotated_reference + .get("sq1_random") + .unwrap() + .len(), + 1 + ); + assert_eq!( + results + .junction_annotations + .unannotated_reference + .get("sq1_random") + .unwrap() + .get(&(Position::new(11).unwrap(), Position::new(21).unwrap())) + .unwrap(), + &1 + ); + } + + #[test] + fn test_summarize() { + // Setup + let mut results = results::JunctionAnnotationResults::default(); + let params = JunctionAnnotationParameters { + min_intron_length: 10, + min_read_support: 2, + min_mapq: Some(MappingQuality::new(30).unwrap()), + no_supplementary: false, + count_secondary: false, + count_duplicates: false, + }; + results.junction_annotations.known.insert( + "sq1".to_string(), + HashMap::from([ + ((Position::new(11).unwrap(), Position::new(21).unwrap()), 4), + ((Position::new(31).unwrap(), Position::new(41).unwrap()), 1), + ((Position::new(21).unwrap(), Position::new(41).unwrap()), 3), + ]), + ); + results.junction_annotations.partial_novel.insert( + "sq1".to_string(), + HashMap::from([ + ((Position::new(11).unwrap(), Position::new(37).unwrap()), 3), + ((Position::new(11).unwrap(), Position::new(15).unwrap()), 2), + ]), + ); + results.junction_annotations.complete_novel.insert( + "sq1".to_string(), + HashMap::from([( + (Position::new(103).unwrap(), Position::new(117).unwrap()), + 2, + )]), + ); + results.junction_annotations.unannotated_reference.insert( + "sq1_random".to_string(), + HashMap::from([((Position::new(1).unwrap(), Position::new(11).unwrap()), 5)]), + ); + + // Test that results are summarized correctly + summarize(&mut results, &params); + assert_eq!(results.summary.known_junctions, 2); + assert_eq!(results.summary.known_junctions_read_support, 7); + assert_eq!(results.summary.partial_novel_junctions, 1); + assert_eq!(results.summary.partial_novel_junctions_read_support, 3); + assert_eq!(results.summary.complete_novel_junctions, 1); + assert_eq!(results.summary.complete_novel_junctions_read_support, 2); + assert_eq!(results.summary.unannotated_reference_junctions, 1); + assert_eq!( + results.summary.unannotated_reference_junctions_read_support, + 5 + ); + assert_eq!(results.summary.total_junctions, 5); + assert_eq!(results.summary.total_junctions_read_support, 17); + assert_eq!(results.summary.known_junctions_percent, 50.0); + assert_eq!(results.summary.partial_novel_junctions_percent, 25.0); + assert_eq!(results.summary.complete_novel_junctions_percent, 25.0); + assert_eq!(results.summary.average_junction_read_support, 3.4); + assert_eq!(results.summary.average_known_junction_read_support, 3.5); + assert_eq!( + results.summary.average_partial_novel_junction_read_support, + 3.0 + ); + assert_eq!( + results.summary.average_complete_novel_junction_read_support, + 2.0 + ); + } +} diff --git a/src/derive/junction_annotation/results.rs b/src/derive/junction_annotation/results.rs new file mode 100644 index 0000000..0ce8f86 --- /dev/null +++ b/src/derive/junction_annotation/results.rs @@ -0,0 +1,218 @@ +//! Results related to the `ngs derive junction-annotation` subcommand.
+ +use noodles::core::Position; +use serde::ser::SerializeStruct; +use serde::{Serialize, Serializer}; +use std::collections::HashMap; + +/// A junction is a tuple of (start, end) coordinates. +pub type Junction = (Position, Position); + +/// A junction counter is a HashMap where the key is a junction and the value is the number of +/// spliced reads that support the junction. +pub type JunctionCounter = HashMap<Junction, usize>; + +/// A map of junctions. The key is the reference name, and the value is a JunctionCounter. +pub type JunctionsMap = HashMap<String, JunctionCounter>; + +/// Lists of annotated junctions. +#[derive(Clone, Debug, Default)] +pub struct JunctionAnnotations { + /// Known junctions. The outer key is the reference name, and the value is another + /// HashMap. The inner key is the (start, end) coordinates of a junction, + /// and the value is the number of spliced reads that support the junction. + pub known: JunctionsMap, + + /// Partially novel junctions. The outer key is the reference name, and the value is another + /// HashMap. The inner key is the (start, end) coordinates of a junction, + /// and the value is the number of spliced reads that support the junction. + pub partial_novel: JunctionsMap, + + /// Complete novel junctions. The outer key is the reference name, and the value is another + /// HashMap. The inner key is the (start, end) coordinates of a junction, + /// and the value is the number of spliced reads that support the junction. + pub complete_novel: JunctionsMap, + + /// Junctions on reference sequences for which junction annotations were not found. + /// The outer key is the reference name, and the value is another + /// HashMap. The inner key is the (start, end) coordinates of a junction, + /// and the value is the number of spliced reads that support the junction. + pub unannotated_reference: JunctionsMap, +} + +// TODO should contigs be sorted?
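+// For illustration only (hypothetical values): a JunctionsMap records, per reference
+// sequence name, how many spliced reads support each (start, end) junction. A single
+// junction on "sq1" supported by 3 reads could be registered as:
+//
+//     let mut map: JunctionsMap = HashMap::new();
+//     map.entry("sq1".to_string())
+//         .or_default()
+//         .insert((Position::new(11).unwrap(), Position::new(21).unwrap()), 3);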
+impl Serialize for JunctionAnnotations { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + let mut known = Vec::new(); + for (ref_name, junctions) in &self.known { + let mut junctions_vec = Vec::new(); + for ((start, end), count) in junctions { + junctions_vec.push((start.get(), end.get(), count)); + } + known.push((ref_name.clone(), junctions_vec)); + } + + let mut partial_novel = Vec::new(); + for (ref_name, junctions) in &self.partial_novel { + let mut junctions_vec = Vec::new(); + for ((start, end), count) in junctions { + junctions_vec.push((start.get(), end.get(), count)); + } + partial_novel.push((ref_name.clone(), junctions_vec)); + } + + let mut complete_novel = Vec::new(); + for (ref_name, junctions) in &self.complete_novel { + let mut junctions_vec = Vec::new(); + for ((start, end), count) in junctions { + junctions_vec.push((start.get(), end.get(), count)); + } + complete_novel.push((ref_name.clone(), junctions_vec)); + } + + let mut unannotated_reference = Vec::new(); + for (ref_name, junctions) in &self.unannotated_reference { + let mut junctions_vec = Vec::new(); + for ((start, end), count) in junctions { + junctions_vec.push((start.get(), end.get(), count)); + } + unannotated_reference.push((ref_name.clone(), junctions_vec)); + } + + let mut s = serializer.serialize_struct("JunctionAnnotations", 4)?; + s.serialize_field("known", &known)?; + s.serialize_field("partial_novel", &partial_novel)?; + s.serialize_field("complete_novel", &complete_novel)?; + s.serialize_field("unannotated_reference", &unannotated_reference)?; + s.end() + } +} + +/// General record metrics that are tallied as a part of the +/// junction-annotation subcommand. +#[derive(Clone, Debug, Default, Serialize)] +pub struct RecordMetrics { + /// The number of records that have been fully processed. + /// This is the number of spliced records that have been considered. + pub processed: usize, + + /// The number of records that have been ignored because of their flags + /// (i.e. they were unmapped, duplicates, secondary, or supplementary). + /// The last 3 conditions can be toggled on/off with command-line flags. + pub filtered_by_flags: usize, + + /// The number of records that have been ignored because they were not + /// spliced. + pub not_spliced: usize, + + /// The number of records with junctions that have been ignored because + /// they failed the MAPQ filter. + /// This could either mean the MAPQ was too low or it was missing. + pub bad_mapq: usize, +} + +/// Summary statistics for the junction-annotation subcommand. +#[derive(Clone, Default, Debug, Serialize)] +pub struct SummaryResults { + /// The total number of junctions observed in the file. + pub total_junctions: usize, + + /// The total number of splices observed in the file. + /// More than one splice can be observed per read, especially + /// with long read data, so this number is not necessarily equal + /// to the number of spliced reads. It may be greater. + pub total_junctions_read_support: usize, + + /// The average number of spliced reads supporting a junction. + pub average_junction_read_support: f64, + + /// The total number of known junctions observed in the file. + pub known_junctions: usize, + + /// The total number of partially novel junctions observed in the file. + pub partial_novel_junctions: usize, + + /// The total number of complete novel junctions observed in the file. + pub complete_novel_junctions: usize, + + /// The total number of junctions on reference sequences for which junction + /// annotations were not found.
+ pub unannotated_reference_junctions: usize, + + /// The number of reads supporting known junctions. + /// If a read supports more than one known junction, it is counted more than once. + /// A read with more than one junction may also contribute to the support of + /// partially novel or completely novel junctions. + pub known_junctions_read_support: usize, + + /// The number of reads supporting partially novel junctions. + /// If a read supports more than one partially novel junction, it is counted more than once. + /// A read with more than one junction may also contribute to the support of + /// known or completely novel junctions. + pub partial_novel_junctions_read_support: usize, + + /// The number of reads supporting completely novel junctions. + /// If a read supports more than one completely novel junction, it is counted more than once. + /// A read with more than one junction may also contribute to the support of + /// known or partially novel junctions. + pub complete_novel_junctions_read_support: usize, + + /// The number of reads supporting junctions on reference sequences for which + /// junction annotations were not found. + /// If a read supports more than one junction, it is counted more than once. + pub unannotated_reference_junctions_read_support: usize, + + /// The percentage of junctions that are known. + /// This percentage excludes junctions on reference sequences for which + /// junction annotations were not found. + pub known_junctions_percent: f64, + + /// The percentage of junctions that are partially novel. + /// This percentage excludes junctions on reference sequences for which + /// junction annotations were not found. + pub partial_novel_junctions_percent: f64, + + /// The percentage of junctions that are completely novel. + /// This percentage excludes junctions on reference sequences for which + /// junction annotations were not found. + pub complete_novel_junctions_percent: f64, + + /// Average number of reads supporting known junctions. + pub average_known_junction_read_support: f64, + + /// Average number of reads supporting partially novel junctions. + pub average_partial_novel_junction_read_support: f64, + + /// Average number of reads supporting completely novel junctions. + pub average_complete_novel_junction_read_support: f64, + + /// The total number of junctions that have been rejected because + /// they failed the --min-read-support or the --min-intron-length filter. + /// A junction can be rejected for both reasons, so this + /// number may not be equal to the sum of junctions_with_not_enough_read_support + /// and intron_too_short. + pub total_rejected_junctions: usize, + + /// The total number of junctions which were discarded due to lack of + /// read support. This is not mutually exclusive with intron_too_short. + pub junctions_with_not_enough_read_support: usize, + + /// The number of junctions that have been ignored because + /// they failed the min_intron_length filter. + /// This is not mutually exclusive with junctions_with_not_enough_read_support. + pub intron_too_short: usize, +} + +/// Main Results struct. This struct aggregates all of the minor metrics structs +/// outlined in this file so they can be kept track of as a unit. +#[derive(Clone, Default, Debug, Serialize)] +pub struct JunctionAnnotationResults { + /// Lists of annotated junctions. + pub junction_annotations: JunctionAnnotations, + + /// General record metrics. + pub records: RecordMetrics, + + /// Summary statistics for the junction-annotation subcommand.
+ pub summary: SummaryResults, +} diff --git a/src/derive/readlen.rs b/src/derive/readlen.rs new file mode 100644 index 0000000..b988896 --- /dev/null +++ b/src/derive/readlen.rs @@ -0,0 +1,3 @@ +//! Supporting functionality for the `ngs derive readlen` subcommand. + +pub mod compute; diff --git a/src/derive/readlen/compute.rs b/src/derive/readlen/compute.rs new file mode 100644 index 0000000..591092d --- /dev/null +++ b/src/derive/readlen/compute.rs @@ -0,0 +1,229 @@ +//! Module holding the logic for computing the consensus read length. + +use serde::Serialize; +use std::collections::HashMap; +use tracing::warn; + +use crate::utils::read_groups::ReadGroupPtr; + +/// Struct holding the per read group results for an `ngs derive readlen` +/// subcommand call. +#[derive(Debug, Serialize)] +pub struct ReadGroupDerivedReadlenResult { + /// The read group that these results are associated with. + pub read_group: String, + + /// Whether or not the `ngs derive readlen` subcommand succeeded + /// for this read group. + pub succeeded: bool, + + /// The consensus read length, if derivable. + pub consensus_read_length: Option<usize>, + + /// The majority vote percentage of the consensus read length. + pub majority_pct_detected: f64, + + /// Status of the evidence that supports (or does not support) the + /// consensus read length. + pub evidence: Vec<(usize, usize)>, +} + +impl ReadGroupDerivedReadlenResult { + /// Creates a new [`ReadGroupDerivedReadlenResult`]. + pub fn new( + read_group: String, + succeeded: bool, + consensus_read_length: Option<usize>, + majority_pct_detected: f64, + evidence: Vec<(usize, usize)>, + ) -> Self { + ReadGroupDerivedReadlenResult { + read_group, + succeeded, + consensus_read_length, + majority_pct_detected, + evidence, + } + } +} + +/// Struct holding the final results for an `ngs derive readlen` subcommand +/// call. +#[derive(Debug, Serialize)] +pub struct DerivedReadlenResult { + /// Whether or not the `ngs derive readlen` subcommand succeeded. + pub succeeded: bool, + + /// The consensus read length, if derivable. + pub consensus_read_length: Option<usize>, + + /// The majority vote percentage of the consensus read length. + pub majority_pct_detected: f64, + + /// Vector of [`ReadGroupDerivedReadlenResult`]s. + /// One for each read group in the BAM, + /// and potentially one for any reads with an unknown read group. + pub read_groups: Vec<ReadGroupDerivedReadlenResult>, + + /// Status of the evidence that supports (or does not support) the + /// consensus read length. + pub evidence: Vec<(usize, usize)>, +} + +impl DerivedReadlenResult { + /// Creates a new [`DerivedReadlenResult`]. + pub fn new( + succeeded: bool, + consensus_read_length: Option<usize>, + majority_pct_detected: f64, + read_groups: Vec<ReadGroupDerivedReadlenResult>, + evidence: Vec<(usize, usize)>, + ) -> Self { + DerivedReadlenResult { + succeeded, + consensus_read_length, + majority_pct_detected, + read_groups, + evidence, + } + } +} + +/// Predicts the consensus read length for a given read group based on the +/// read lengths and a majority vote cutoff.
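+/// Note that the consensus candidate is the longest observed read length, not
+/// necessarily the most frequent one. For example (hypothetical counts): with
+/// read lengths {100: 90, 75: 10} and a majority vote cutoff of 0.7, length 100
+/// is supported by 90% of the reads and becomes the consensus; with
+/// {100: 10, 75: 90}, the candidate 100 only reaches 10% support and no
+/// consensus is called.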
+pub fn predict_readlen( + read_group: String, + read_lengths: &HashMap<usize, usize>, + majority_vote_cutoff: f64, +) -> ReadGroupDerivedReadlenResult { + let mut read_lengths: Vec<(usize, usize)> = + read_lengths.iter().map(|(k, v)| (*k, *v)).collect(); + + read_lengths.sort_by(|a, b| b.0.cmp(&a.0)); + + // Tally the number of reads + let num_reads: usize = read_lengths.iter().map(|(_, count)| count).sum(); + + let (majority_detected, consensus_read_length) = match num_reads == 0 { + true => { + warn!("No reads were detected for read group: {}", read_group); + (0.0, None) + } + false => ( + read_lengths[0].1 as f64 / num_reads as f64, + Some(read_lengths[0].0), + ), + }; + + match majority_detected >= majority_vote_cutoff { + true => ReadGroupDerivedReadlenResult::new( + read_group, + true, + consensus_read_length, + majority_detected * 100.0, + read_lengths, + ), + false => ReadGroupDerivedReadlenResult::new( + read_group, + false, + None, + majority_detected * 100.0, + read_lengths, + ), + } +} + +/// Main method to evaluate the collected read lengths and +/// return a result for the consensus read length. This may fail, and the +/// resulting [`DerivedReadlenResult`] should be evaluated accordingly. +pub fn predict( + read_lengths: HashMap<ReadGroupPtr, HashMap<usize, usize>>, + majority_vote_cutoff: f64, +) -> DerivedReadlenResult { + // Iterate over the read lengths and predict the consensus read length. + let mut rg_results = Vec::new(); + let mut overall_lengths = HashMap::new(); + + for (read_group, lengths) in read_lengths { + let result = predict_readlen(read_group.to_string(), &lengths, majority_vote_cutoff); + rg_results.push(result); + + for (length, count) in lengths { + *overall_lengths.entry(length).or_default() += count; + } + } + + let overall_result = predict_readlen( + "overall".to_string(), + &overall_lengths, + majority_vote_cutoff, + ); + + // Sort the read lengths by their key for output.
+ let mut overall_lengths: Vec<(usize, usize)> = overall_lengths.into_iter().collect(); + overall_lengths.sort_by(|a, b| b.0.cmp(&a.0)); + + DerivedReadlenResult::new( + overall_result.succeeded, + overall_result.consensus_read_length, + overall_result.majority_pct_detected, + rg_results, + overall_lengths, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_derive_readlen_from_empty_hashmap() { + let read_lengths = HashMap::new(); + let result = predict(read_lengths, 0.7); + assert!(!result.succeeded); + assert_eq!(result.consensus_read_length, None); + assert_eq!(result.majority_pct_detected, 0.0); + assert_eq!(result.evidence, Vec::new()); + } + + #[test] + fn test_derive_readlen_when_all_readlengths_equal() { + let read_lengths = + HashMap::from([(Arc::new("RG1".to_string()), HashMap::from([(100, 10)]))]); + let result = predict(read_lengths, 1.0); + assert!(result.succeeded); + assert_eq!(result.consensus_read_length, Some(100)); + assert_eq!(result.majority_pct_detected, 100.0); + assert_eq!(result.evidence, Vec::from([(100, 10)])); + } + + #[test] + fn test_derive_readlen_success_when_not_all_readlengths_equal() { + let read_lengths = HashMap::from([( + Arc::new("RG1".to_string()), + HashMap::from([(101, 1000), (100, 5), (99, 5)]), + )]); + let result = predict(read_lengths, 0.7); + assert!(result.succeeded); + assert_eq!(result.consensus_read_length, Some(101)); + assert!(result.majority_pct_detected > 99.0); + assert_eq!(result.evidence, Vec::from([(101, 1000), (100, 5), (99, 5)])); + } + + #[test] + fn test_derive_readlen_fail_when_not_all_readlengths_equal() { + let read_lengths = HashMap::from([ + ( + Arc::new("RG1".to_string()), + HashMap::from([(101, 5), (99, 5)]), + ), + (Arc::new("RG2".to_string()), HashMap::from([(100, 1000)])), + ]); + let result = predict(read_lengths, 0.7); + assert!(!result.succeeded); + assert_eq!(result.consensus_read_length, None); + assert!(result.majority_pct_detected < 0.7); + assert_eq!(result.evidence, Vec::from([(101, 5), (100, 1000), (99, 5)])); + } +} diff --git a/src/derive/strandedness.rs b/src/derive/strandedness.rs new file mode 100644 index 0000000..4349a3c --- /dev/null +++ b/src/derive/strandedness.rs @@ -0,0 +1,4 @@ +//! Supporting functionality for the `ngs derive strandedness` subcommand. + +pub mod compute; +pub mod results; diff --git a/src/derive/strandedness/compute.rs b/src/derive/strandedness/compute.rs new file mode 100644 index 0000000..19a64b4 --- /dev/null +++ b/src/derive/strandedness/compute.rs @@ -0,0 +1,589 @@ +//! Module holding the logic for computing the strandedness. + +use noodles::core::Region; +use noodles::sam::record::MappingQuality; +use noodles::{bam, gff, sam}; +use rand::Rng; +use rust_lapper::Lapper; +use std::collections::{HashMap, HashSet}; +use std::ops::{Add, AddAssign}; +use std::sync::Arc; + +use crate::derive::strandedness::results; +use crate::utils::alignment::filter_by_mapq; +use crate::utils::display::RecordCounter; +use crate::utils::read_groups; + +const STRANDED_THRESHOLD: f64 = 80.0; +const UNSTRANDED_THRESHOLD: f64 = 40.0; + +/// Struct for tracking count results. +#[derive(Clone, Default)] +pub struct Counts { + /// The number of reads that are evidence of Forward Strandedness. + forward: usize, + + /// The number of reads that are evidence of Reverse Strandedness. 
+ reverse: usize, +} + +impl Add for Counts { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self { + forward: self.forward + other.forward, + reverse: self.reverse + other.reverse, + } + } +} + +impl AddAssign for Counts { + fn add_assign(&mut self, other: Self) { + self.forward += other.forward; + self.reverse += other.reverse; + } +} + +/// Struct for valid strand orientations. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Strand { + /// Forward strand. + Forward, + + /// Reverse strand. + Reverse, +} + +impl From<sam::record::Flags> for Strand { + fn from(flags: sam::record::Flags) -> Self { + if flags.is_reverse_complemented() { + Self::Reverse + } else { + Self::Forward + } + } +} + +impl TryFrom<gff::record::Strand> for Strand { + type Error = (); + + fn try_from(strand: gff::record::Strand) -> Result<Self, Self::Error> { + match strand { + gff::record::Strand::Forward => Ok(Self::Forward), + gff::record::Strand::Reverse => Ok(Self::Reverse), + _ => Err(()), + } + } +} + +/// Struct for tracking the order of segments in a record. +#[derive(Clone, Copy, Debug)] +enum SegmentOrder { + /// The first segment in a record. + First, + + /// The last segment in a record. + Last, +} + +impl TryFrom<sam::record::Flags> for SegmentOrder { + type Error = String; + + fn try_from(flags: sam::record::Flags) -> Result<Self, Self::Error> { + if !flags.is_segmented() { + Err("Expected segmented record.".to_string()) + } else if flags.is_first_segment() && !flags.is_last_segment() { + Ok(SegmentOrder::First) + } else if flags.is_last_segment() && !flags.is_first_segment() { + Ok(SegmentOrder::Last) + } else { + Err("Expected first or last segment.".to_string()) + } + } +} + +/// Struct holding the parsed BAM file and its index. +/// TODO this code is repeated. Should be in a common module. +/// Will be moved to utils::formats in a future PR. +pub struct ParsedBAMFile { + /// The BAM reader. + pub reader: bam::Reader<noodles::bgzf::Reader<std::fs::File>>, + + /// The BAM header. + pub header: sam::Header, + + /// The BAM index. + pub index: bam::bai::Index, +} + +/// Struct holding the counts for all read groups. +/// Also holds the set of read groups found in the BAM. +pub struct AllReadGroupsCounts { + /// The counts for all read groups. + pub counts: HashMap<Arc<String>, Counts>, + + /// The set of read groups found in the BAM. + pub found_rgs: HashSet<Arc<String>>, +} + +/// Parameters defining how to calculate strandedness. +pub struct StrandednessParams { + /// The number of genes to test for strandedness. + pub num_genes: usize, + + /// The maximum number of genes to try before giving up. + pub max_genes_per_try: usize, + + /// Minimum number of reads mapped to a gene to be considered + /// for evidence of strandedness. + pub min_reads_per_gene: usize, + + /// Minimum mapping quality for a record to be considered. + /// `None` means no filtering by MAPQ. This allows + /// for records _without_ a MAPQ to be counted. + pub min_mapq: Option<MappingQuality>, + + /// Allow qc failed reads to be counted. + pub count_qc_failed: bool, + + /// Do not count supplementary alignments. + pub no_supplementary: bool, + + /// Do count secondary alignments. + pub count_secondary: bool, + + /// Do count duplicates. + pub count_duplicates: bool, +} + +/// Function to disqualify a gene based on its strand and its exons' strand. +fn disqualify_gene(gene: &gff::Record, exons: &HashMap<&str, Lapper<usize, Strand>>) -> bool { + // gene_strand guaranteed to be Forward or Reverse by initialization code. + let gene_strand = Strand::try_from(gene.strand()).unwrap(); + let mut at_least_one_exon = false; + let mut all_on_same_strand = true; + + if let Some(intervals) = exons.get(gene.reference_sequence_name()) { + for exon in intervals.find(gene.start().into(), gene.end().into()) { + at_least_one_exon = true; + if exon.val != gene_strand { + all_on_same_strand = false; + break; + } + } + } + + // Keep the gene only if it overlapped at least one exon and every + // overlapping exon was on the gene's strand. + !(all_on_same_strand && at_least_one_exon) +} + +/// Function to filter out records based on their flags. +fn filter_by_flags(record: &sam::alignment::Record, params: &StrandednessParams) -> bool { + let flags = record.flags(); + if (!params.count_qc_failed && flags.is_qc_fail()) + || (params.no_supplementary && flags.is_supplementary()) + || (!params.count_secondary && flags.is_secondary()) + || (!params.count_duplicates && flags.is_duplicate()) + { + return true; + } + false +} + +/// Function to query the BAM file and filter the records based on the +/// parameters provided. +fn query_and_filter( + parsed_bam: &mut ParsedBAMFile, + gene: &gff::Record, + params: &StrandednessParams, + read_metrics: &mut results::ReadRecordMetrics, +) -> Vec<sam::alignment::Record> { + let start = gene.start(); + let end = gene.end(); + let region = Region::new(gene.reference_sequence_name(), start..=end); + + let mut filtered_reads = Vec::new(); + + let query = parsed_bam + .reader + .query(&parsed_bam.header, &parsed_bam.index, &region) + .unwrap(); + for read in query { + let read = read.unwrap(); + + // (1) Filter by flags. + if filter_by_flags(&read, params) { + read_metrics.filtered_by_flags += 1; + continue; + } + + // (2) Filter by MAPQ. + if filter_by_mapq(&read, params.min_mapq) { + read_metrics.bad_mapq += 1; + continue; + } + + filtered_reads.push(read); + } + + if filtered_reads.len() < params.min_reads_per_gene { + filtered_reads.clear(); + } + + filtered_reads +} + +/// Function to classify a read based on its strand and the strand of the gene. +fn classify_read( + read: &sam::alignment::Record, + gene_strand: &Strand, + all_counts: &mut AllReadGroupsCounts, + read_metrics: &mut results::ReadRecordMetrics, +) { + let read_group = read_groups::get_read_group(read, Some(&mut all_counts.found_rgs)); + + let rg_counts = all_counts.counts.entry(read_group).or_default(); + + let read_strand = Strand::from(read.flags()); + if read.flags().is_segmented() { + read_metrics.paired_end_reads += 1; + + let order = SegmentOrder::try_from(read.flags()).unwrap(); + + match (order, read_strand, gene_strand) { + (SegmentOrder::First, Strand::Forward, Strand::Forward) + | (SegmentOrder::First, Strand::Reverse, Strand::Reverse) + | (SegmentOrder::Last, Strand::Forward, Strand::Reverse) + | (SegmentOrder::Last, Strand::Reverse, Strand::Forward) => { + rg_counts.forward += 1; + } + (SegmentOrder::First, Strand::Forward, Strand::Reverse) + | (SegmentOrder::First, Strand::Reverse, Strand::Forward) + | (SegmentOrder::Last, Strand::Forward, Strand::Forward) + | (SegmentOrder::Last, Strand::Reverse, Strand::Reverse) => { + rg_counts.reverse += 1; + } + } + } else { + read_metrics.single_end_reads += 1; + + match (read_strand, gene_strand) { + (Strand::Forward, Strand::Forward) | (Strand::Reverse, Strand::Reverse) => { + rg_counts.forward += 1; + } + (Strand::Forward, Strand::Reverse) | (Strand::Reverse, Strand::Forward) => { + rg_counts.reverse += 1; + } + } + } +} + +/// Method to predict the strandedness of a read group.
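+/// A read group is called "Forward" or "Reverse" when more than 80% of its reads
+/// support that strand, and "Unstranded" when both strands exceed 40%. For example
+/// (hypothetical counts): 90 forward / 10 reverse yields "Forward", 50/50 yields
+/// "Unstranded", and 70/30 produces no call.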
+pub fn predict_strandedness( + rg_name: &str, + counts: &Counts, +) -> results::ReadGroupDerivedStrandednessResult { + if counts.forward == 0 && counts.reverse == 0 { + return results::ReadGroupDerivedStrandednessResult { + read_group: rg_name.to_string(), + succeeded: false, + strandedness: None, + total: 0, + forward: 0, + reverse: 0, + forward_pct: 0.0, + reverse_pct: 0.0, + }; + } + let mut result = results::ReadGroupDerivedStrandednessResult::new( + rg_name.to_string(), + false, + None, + counts.forward, + counts.reverse, + ); + + if result.forward_pct > STRANDED_THRESHOLD { + result.succeeded = true; + result.strandedness = Some("Forward".to_string()); + } else if result.reverse_pct > STRANDED_THRESHOLD { + result.succeeded = true; + result.strandedness = Some("Reverse".to_string()); + } else if result.forward_pct > UNSTRANDED_THRESHOLD && result.reverse_pct > UNSTRANDED_THRESHOLD + { + result.succeeded = true; + result.strandedness = Some("Unstranded".to_string()); + } // else did not succeed + + result +} + +/// Main method to evaluate the observed strand state and +/// return a result for the derived strandedness. This may fail, and the +/// resulting [`results::DerivedStrandednessResult`] should be evaluated accordingly. +pub fn predict( + parsed_bam: &mut ParsedBAMFile, + gene_records: &mut Vec<gff::Record>, + exons: &HashMap<&str, Lapper<usize, Strand>>, + all_counts: &mut AllReadGroupsCounts, + params: &StrandednessParams, + metrics: &mut results::RecordTracker, +) -> Result<results::DerivedStrandednessResult, anyhow::Error> { + let mut rng = rand::thread_rng(); + let mut num_genes_considered = 0; // Local to this attempt + let mut counter = RecordCounter::new(Some(1_000)); // Also local to this attempt + let genes_remaining = gene_records.len(); + + let max_iters = if params.max_genes_per_try > genes_remaining { + tracing::warn!( + "The number of genes remaining ({}) is less than the --max-genes-per-try ({}).", + genes_remaining, + params.max_genes_per_try, + ); + genes_remaining + } else { + params.max_genes_per_try + }; + + for _ in 0..max_iters { + if num_genes_considered >= params.num_genes { + tracing::info!( + "Reached the maximum number of considered genes ({}) for this try.", + num_genes_considered, + ); + break; + } + + let cur_gene = gene_records.swap_remove(rng.gen_range(0..gene_records.len())); + counter.inc(); // Technically this is off-by-one, as the gene hasn't been processed yet. + + if disqualify_gene(&cur_gene, exons) { + metrics.genes.mixed_strands += 1; // Tracked across attempts + continue; + } + // gene_strand guaranteed to be Forward or Reverse by initialization code. + let cur_gene_strand = Strand::try_from(cur_gene.strand()).unwrap(); + + let mut enough_reads = false; + for read in query_and_filter(parsed_bam, &cur_gene, params, &mut metrics.reads) { + enough_reads = true; + + classify_read(&read, &cur_gene_strand, all_counts, &mut metrics.reads); + } + if enough_reads { + num_genes_considered += 1; + } else { + metrics.genes.not_enough_reads += 1; // Tracked across attempts + } + } + if num_genes_considered < params.num_genes { + tracing::warn!( + "Reached the maximum number of genes to evaluate ({}) before the requested number of genes ({}) could be considered. Only {} genes were considered this try.", + max_iters, + params.num_genes, + num_genes_considered, + ); + } + + metrics.genes.considered += num_genes_considered; // Add to any other attempts + metrics.genes.evaluated += counter.get(); // Add to any other attempts + + let mut overall_counts = Counts::default(); + let mut rg_results = Vec::new(); + for (rg, counts) in &all_counts.counts { + overall_counts += counts.clone(); + + let result = predict_strandedness(rg, counts); + rg_results.push(result) + } + + let overall_result = predict_strandedness("overall", &overall_counts); + let final_result = results::DerivedStrandednessResult::new( + overall_result.succeeded, + overall_result.strandedness, + overall_result.forward, + overall_result.reverse, + rg_results, + metrics.clone(), + ); + + anyhow::Ok(final_result) +} + +#[cfg(test)] +mod tests { + use super::*; + use noodles::sam::record::data::field::Tag; + use rust_lapper::Interval; + + #[test] + fn test_disqualify_gene() { + // test mixed strands + let mut exons = HashMap::new(); + exons.insert( + "chr1", + Lapper::new(vec![ + Interval { + start: 1, + stop: 10, + val: Strand::Forward, + }, + Interval { + start: 11, + stop: 20, + val: Strand::Reverse, + }, + ]), + ); + + let s = "chr1\tNOODLES\tgene\t5\t14\t.\t+\t.\tgene_id=ndls0;gene_name=gene0"; + let record = s.parse::<gff::Record>().unwrap(); + assert!(disqualify_gene(&record, &exons)); // disqualified + + // test all on same strand + let mut exons = HashMap::new(); + exons.insert( + "chr1", + Lapper::new(vec![ + Interval { + start: 1, + stop: 10, + val: Strand::Forward, + }, + Interval { + start: 11, + stop: 20, + val: Strand::Forward, + }, + ]), + ); + + assert!(!disqualify_gene(&record, &exons)); // accepted + + // test no exons + let exons = HashMap::new(); + assert!(disqualify_gene(&record, &exons)); // disqualified + } + + #[test] + fn test_classify_read() { + // Set up + let mut all_counts = AllReadGroupsCounts { + counts: HashMap::new(), + found_rgs: HashSet::new(), + }; + let mut read_metrics = results::ReadRecordMetrics::default(); + let counts_key = Arc::new("rg1".to_string()); + let rg_tag = sam::record::data::field::Value::String("rg1".to_string()); + + // Test Single-End read. Evidence for Forward Strandedness. + let mut read = sam::alignment::Record::default(); + read.flags_mut().set(0x1.into(), false); + read.data_mut().insert(Tag::ReadGroup, rg_tag.clone()); + classify_read(&read, &Strand::Forward, &mut all_counts, &mut read_metrics); + assert_eq!(read_metrics.paired_end_reads, 0); + assert_eq!(read_metrics.single_end_reads, 1); + assert_eq!(read_metrics.filtered_by_flags, 0); + assert_eq!(read_metrics.bad_mapq, 0); + let counts = all_counts.counts.get(&counts_key).unwrap(); + assert_eq!(counts.forward, 1); + assert_eq!(counts.reverse, 0); + + // Test Paired-End read. Evidence for Forward Strandedness. + let mut read = sam::alignment::Record::default(); + read.flags_mut().set(0x1.into(), true); + read.flags_mut().set(0x40.into(), true); + read.data_mut().insert(Tag::ReadGroup, rg_tag.clone()); + classify_read(&read, &Strand::Forward, &mut all_counts, &mut read_metrics); + assert_eq!(read_metrics.paired_end_reads, 1); + assert_eq!(read_metrics.single_end_reads, 1); + assert_eq!(read_metrics.filtered_by_flags, 0); + assert_eq!(read_metrics.bad_mapq, 0); + let counts = all_counts.counts.get(&counts_key).unwrap(); + assert_eq!(counts.forward, 2); + assert_eq!(counts.reverse, 0); + + // Test Paired-End read. Evidence for Forward Strandedness.
+ let mut read = sam::alignment::Record::default(); + read.flags_mut().set(0x1.into(), true); + read.flags_mut().set(0x80.into(), true); + read.data_mut().insert(Tag::ReadGroup, rg_tag.clone()); + classify_read(&read, &Strand::Reverse, &mut all_counts, &mut read_metrics); + assert_eq!(read_metrics.paired_end_reads, 2); + assert_eq!(read_metrics.single_end_reads, 1); + assert_eq!(read_metrics.filtered_by_flags, 0); + assert_eq!(read_metrics.bad_mapq, 0); + let counts = all_counts.counts.get(&counts_key).unwrap(); + assert_eq!(counts.forward, 3); + assert_eq!(counts.reverse, 0); + + // Test Paired-End read. Evidence for Reverse Strandedness. + let mut read = sam::alignment::Record::default(); + read.flags_mut().set(0x1.into(), true); + read.flags_mut().set(0x40.into(), true); + read.data_mut().insert(Tag::ReadGroup, rg_tag.clone()); + classify_read(&read, &Strand::Reverse, &mut all_counts, &mut read_metrics); + assert_eq!(read_metrics.paired_end_reads, 3); + assert_eq!(read_metrics.single_end_reads, 1); + assert_eq!(read_metrics.filtered_by_flags, 0); + assert_eq!(read_metrics.bad_mapq, 0); + let counts = all_counts.counts.get(&counts_key).unwrap(); + assert_eq!(counts.forward, 3); + assert_eq!(counts.reverse, 1); + } + + #[test] + fn test_predict_strandedness() { + let counts = Counts { + forward: 10, + reverse: 90, + }; + let result = predict_strandedness("rg1", &counts); + assert!(result.succeeded); + assert_eq!(result.strandedness, Some("Reverse".to_string())); + assert_eq!(result.forward, 10); + assert_eq!(result.reverse, 90); + assert_eq!(result.forward_pct, 10.0); + assert_eq!(result.reverse_pct, 90.0); + + let counts = Counts { + forward: 50, + reverse: 50, + }; + let result = predict_strandedness("rg1", &counts); + assert!(result.succeeded); + assert_eq!(result.strandedness, Some("Unstranded".to_string())); + assert_eq!(result.forward, 50); + assert_eq!(result.reverse, 50); + assert_eq!(result.forward_pct, 50.0); + assert_eq!(result.reverse_pct, 50.0); + + let counts = Counts { + forward: 90, + reverse: 10, + }; + let result = predict_strandedness("rg1", &counts); + assert!(result.succeeded); + assert_eq!(result.strandedness, Some("Forward".to_string())); + assert_eq!(result.forward, 90); + assert_eq!(result.reverse, 10); + assert_eq!(result.forward_pct, 90.0); + assert_eq!(result.reverse_pct, 10.0); + + let counts = Counts { + forward: 0, + reverse: 0, + }; + let result = predict_strandedness("rg1", &counts); + assert!(!result.succeeded); + assert_eq!(result.strandedness, None); + assert_eq!(result.forward, 0); + assert_eq!(result.reverse, 0); + assert_eq!(result.forward_pct, 0.0); + assert_eq!(result.reverse_pct, 0.0); + } +} diff --git a/src/derive/strandedness/results.rs b/src/derive/strandedness/results.rs new file mode 100644 index 0000000..3082ba8 --- /dev/null +++ b/src/derive/strandedness/results.rs @@ -0,0 +1,196 @@ +//! Results structs for the strandedness subcommand. + +use serde::Serialize; + +/// General read record metrics that are tallied as a part of the +/// strandedness subcommand. +#[derive(Clone, Default, Serialize, Debug)] +pub struct ReadRecordMetrics { + /// The number of records that have been filtered because of their flags. + /// (i.e. they were qc_fail, duplicates, secondary, or supplementary) + /// These conditions can be toggled on/off with CL flags + pub filtered_by_flags: usize, + + /// The number of records that have been ignored because they failed the MAPQ filter. + pub bad_mapq: usize, + + /// The number of records determined to be Paired-End. 
+ pub paired_end_reads: usize, + + /// The number of records determined to be Single-End. + pub single_end_reads: usize, +} + +/// General gene metrics that are tallied as a part of the +/// strandedness subcommand. +#[derive(Clone, Default, Serialize, Debug)] +pub struct GeneRecordMetrics { + /// The total number of genes found in the GFF. + pub total: usize, + + /// The number of genes that were found to be protein coding. + /// If --all-genes is set this will not be tallied. + pub protein_coding: usize, + + /// The number of genes which were discarded due to having + /// an unknown or missing strand. + pub bad_strand: usize, + + /// The number of genes randomly selected for evaluation. + pub evaluated: usize, + + /// The number of genes which were discarded due to having + /// mixed strands (the gene has exons on both strands). + pub mixed_strands: usize, + + /// The number of genes which were discarded due to not having + /// enough reads. + pub not_enough_reads: usize, + + /// The number of genes considered for strandedness evidence. + pub considered: usize, +} + +/// General exon metrics that are tallied as a part of the +/// strandedness subcommand. +#[derive(Clone, Default, Serialize, Debug)] +pub struct ExonRecordMetrics { + /// The total number of exons found in the GFF. + pub total: usize, + + /// The number of exons discarded due to having an unknown/invalid strand. + pub bad_strand: usize, +} + +/// Struct for managing record tracking. +#[derive(Clone, Default, Debug)] +pub struct RecordTracker { + /// Read metrics. + pub reads: ReadRecordMetrics, + + /// Gene metrics. + pub genes: GeneRecordMetrics, + + /// Exon metrics. + pub exons: ExonRecordMetrics, +} + +/// Struct holding the per read group results for an `ngs derive strandedness` +/// subcommand call. +#[derive(Debug, Serialize)] +pub struct ReadGroupDerivedStrandednessResult { + /// Name of the read group. + pub read_group: String, + + /// Whether or not strandedness was determined for this read group. + pub succeeded: bool, + + /// The strandedness of this read group, if derivable. + pub strandedness: Option<String>, + + /// The total number of reads in this read group. + pub total: usize, + + /// The number of reads that are evidence of Forward Strandedness. + pub forward: usize, + + /// The number of reads that are evidence of Reverse Strandedness. + pub reverse: usize, + + /// The percent of evidence for Forward Strandedness. + pub forward_pct: f64, + + /// The percent of evidence for Reverse Strandedness. + pub reverse_pct: f64, +} + +impl ReadGroupDerivedStrandednessResult { + /// Creates a new [`ReadGroupDerivedStrandednessResult`]. + pub fn new( + read_group: String, + succeeded: bool, + strandedness: Option<String>, + forward: usize, + reverse: usize, + ) -> Self { + ReadGroupDerivedStrandednessResult { + read_group, + succeeded, + strandedness, + total: forward + reverse, + forward, + reverse, + forward_pct: (forward as f64 / (forward + reverse) as f64) * 100.0, + reverse_pct: (reverse as f64 / (forward + reverse) as f64) * 100.0, + } + } +} + +/// Struct holding the final results for an `ngs derive strandedness` subcommand +/// call. +#[derive(Debug, Serialize)] +pub struct DerivedStrandednessResult { + /// Whether or not the `ngs derive strandedness` subcommand succeeded. + pub succeeded: bool, + + /// The strandedness of the file, if derivable. + pub strandedness: Option<String>, + + /// The total number of reads. + pub total: usize, + + /// The number of reads that are evidence of Forward Strandedness.
+ pub forward: usize, + + /// The number of reads that are evidence of Reverse Strandedness. + pub reverse: usize, + + /// The percent of evidence for Forward Strandedness. + pub forward_pct: f64, + + /// The percent of evidence for Reverse Strandedness. + pub reverse_pct: f64, + + /// Vector of [`ReadGroupDerivedStrandednessResult`]s. + /// One for each read group in the BAM, + /// and potentially one for any reads with an unknown read group. + pub read_groups: Vec<ReadGroupDerivedStrandednessResult>, + + /// General read record metrics that are tallied as a part of the + /// strandedness subcommand. + pub read_metrics: ReadRecordMetrics, + + /// General gene metrics that are tallied as a part of the + /// strandedness subcommand. + pub gene_metrics: GeneRecordMetrics, + + /// General exon metrics that are tallied as a part of the + /// strandedness subcommand. + pub exon_metrics: ExonRecordMetrics, +} + +impl DerivedStrandednessResult { + /// Creates a new [`DerivedStrandednessResult`]. + pub fn new( + succeeded: bool, + strandedness: Option<String>, + forward: usize, + reverse: usize, + read_groups: Vec<ReadGroupDerivedStrandednessResult>, + metrics: RecordTracker, + ) -> Self { + DerivedStrandednessResult { + succeeded, + strandedness, + total: forward + reverse, + forward, + reverse, + forward_pct: (forward as f64 / (forward + reverse) as f64) * 100.0, + reverse_pct: (reverse as f64 / (forward + reverse) as f64) * 100.0, + read_groups, + read_metrics: metrics.reads, + gene_metrics: metrics.genes, + exon_metrics: metrics.exons, + } + } +} diff --git a/src/generate/command.rs b/src/generate/command.rs index a6acc77..f106d30 100644 --- a/src/generate/command.rs +++ b/src/generate/command.rs @@ -12,21 +12,9 @@ use tracing::info; use crate::generate::providers::reference_provider::ReferenceGenomeSequenceProvider; use crate::generate::providers::SequenceProvider; +use crate::utils::args::arg_in_range as error_rate_in_range; use crate::utils::formats; -/// Utility method to parse the error rate passed in on the command line and -/// ensure the rate is within the range [0.0, 1.0]. -pub fn error_rate_in_range(error_rate_raw: &str) -> Result<f32, String> { - let error_rate: f32 = error_rate_raw - .parse() - .map_err(|_| format!("{} isn't a float", error_rate_raw))?; - - match (0.0..=1.0).contains(&error_rate) { - true => Ok(error_rate), - false => Err(String::from("Error rate must be between 0.0 and 1.0")), - } -} - /// Command line arguments for `ngs generate`. #[derive(Args)] #[command(group(ArgGroup::new("record-count").required(true).args(["coverage", "num_records"])))] @@ -42,9 +30,8 @@ pub struct GenerateArgs { reference_providers: Vec<String>, /// The error rate for the sequencer as a fraction between [0.0, 1.0] (per base). - #[arg(short, long, value_name = "F32", default_value = "0.0001")] - #[arg(value_parser = error_rate_in_range)] - error_rate: Option<f32>, + #[arg(short, long, value_name = "F64", default_value = "0.0001")] + error_rate: f64, /// Specifies the number of records to generate. #[arg(short, long, value_name = "USIZE", conflicts_with = "coverage")] num_records: Option<usize>, @@ -58,6 +45,9 @@ /// Main function for the `ngs generate` subcommand. pub fn generate(args: GenerateArgs) -> anyhow::Result<()> { // (0) Parse arguments needed for subcommand.
+ let _error_rate = error_rate_in_range(args.error_rate, 0.0..=1.0) + .with_context(|| "Error rate is not within acceptable range")?; + let result: anyhow::Result<Vec<ReferenceGenomeSequenceProvider>> = args .reference_providers .iter() diff --git a/src/index/bam.rs b/src/index/bam.rs index d9e3b4a..c663ac1 100644 --- a/src/index/bam.rs +++ b/src/index/bam.rs @@ -69,7 +69,7 @@ pub fn index(src: PathBuf) -> anyhow::Result<()> { let mut builder = bai::Index::builder(); let mut start_position = reader.virtual_position(); - let mut counter = RecordCounter::new(); + let mut counter = RecordCounter::default(); loop { match reader.read_record(&header.parsed, &mut record) { diff --git a/src/main.rs b/src/main.rs index 1441fef..6f5ced0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -89,9 +89,24 @@ fn main() -> anyhow::Result<()> { match cli.subcommand { Subcommands::Convert(args) => convert::command::convert(args)?, Subcommands::Derive(args) => match args.subcommand { + derive::command::DeriveSubcommand::Encoding(args) => { + derive::command::encoding::derive(args)? + } + derive::command::DeriveSubcommand::Endedness(args) => { + derive::command::endedness::derive(args)? + } derive::command::DeriveSubcommand::Instrument(args) => { derive::command::instrument::derive(args)? } + derive::command::DeriveSubcommand::Readlen(args) => { + derive::command::readlen::derive(args)? + } + derive::command::DeriveSubcommand::Strandedness(args) => { + derive::command::strandedness::derive(args)? + } + derive::command::DeriveSubcommand::JunctionAnnotation(args) => { + derive::command::junction_annotation::derive(args)? + } }, Subcommands::Generate(args) => generate::command::generate(args)?, Subcommands::Index(args) => index::command::index(args)?, diff --git a/src/qc/command.rs b/src/qc/command.rs index eeaa29c..6fac58d 100644 --- a/src/qc/command.rs +++ b/src/qc/command.rs @@ -50,8 +50,13 @@ pub struct QcArgs { /// to process per sequence in the second pass. /// /// This is generally only used for testing purposes. - #[arg(short = 'n', long, value_name = "USIZE")] - num_records: Option<usize>, + #[arg( + short, + long, + default_value_t, + value_name = "'all' or a positive, non-zero integer" + )] + num_records: NumberOfRecords, /// Directory to output files to. Defaults to current working directory.
#[arg(short = 'o', long, value_name = "PATH")] @@ -201,8 +206,6 @@ pub fn qc(args: QcArgs) -> anyhow::Result<()> { // Number of Records // //===================// - let num_records = NumberOfRecords::from(args.num_records); - app( src, reference_fasta, @@ -210,7 +213,7 @@ reference_genome, output_prefix, output_directory, - num_records, + args.num_records, feature_names, only_facet, vaf_file_path, @@ -300,7 +303,7 @@ fn app( //====================================================================// info!("Starting first pass for QC stats."); - let mut counter = RecordCounter::new(); + let mut counter = RecordCounter::default(); for result in reader.records(&header.parsed) { let record = result?; @@ -349,9 +352,9 @@ info!("Starting second pass for QC stats."); let mut reader = File::open(&src).map(bam::Reader::new)?; let index = - bai::read(&src.with_extension("bam.bai")).with_context(|| "reading BAM index")?; + bai::read(src.with_extension("bam.bai")).with_context(|| "reading BAM index")?; - let mut counter = RecordCounter::new(); + let mut counter = RecordCounter::default(); for (name, seq) in header.parsed.reference_sequences() { let start = Position::MIN; diff --git a/src/utils.rs b/src/utils.rs index 9a33a4e..8f0207c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -8,3 +8,4 @@ pub mod formats; pub mod genome; pub mod histogram; pub mod pathbuf; +pub mod read_groups; diff --git a/src/utils/alignment.rs b/src/utils/alignment.rs index 1a731f1..678114f 100644 --- a/src/utils/alignment.rs +++ b/src/utils/alignment.rs @@ -1,11 +1,25 @@ //! Utilities related to alignment of sequences. use anyhow::bail; -use noodles::sam::record::{cigar::op::Kind, sequence::Base, Cigar}; +use noodles::sam::record::{cigar::op::Kind, sequence::Base, Cigar, MappingQuality}; use super::cigar::consumes_reference; use super::cigar::consumes_sequence; +/// Filter an alignment record by its mapping quality. `true` means "filter the record" and `false` means "do not filter the record". +pub fn filter_by_mapq( + record: &noodles::sam::alignment::Record, + min_mapq: Option<MappingQuality>, +) -> bool { + match min_mapq { + Some(min_mapq) => match record.mapping_quality() { + Some(mapq) => mapq.get() < min_mapq.get(), + None => true, // Filter if no MAPQ is present + }, + None => false, // Do not filter if no min MAPQ is specified + } +} + /// Turns a condensed Cigar representation into a flattened representation. For /// example, 10M will become a Vec of length 10 comprised completely of /// Kind::MATCH.
This utility is useful for generating a representation of a @@ -127,9 +141,36 @@ impl<'a> ReferenceRecordStepThrough<'a> { #[cfg(test)] mod tests { - use noodles::sam::record::{Cigar, Sequence}; + use noodles::sam::record::{Cigar, MappingQuality, Sequence}; + + use super::*; - use super::ReferenceRecordStepThrough; + #[test] + pub fn it_filters_by_mapq() -> anyhow::Result<()> { + let mut record = noodles::sam::alignment::Record::default(); + assert!(filter_by_mapq( + &record, + Some(MappingQuality::new(0).unwrap()) + )); // Get filtered because MAPQ is missing + assert!(!filter_by_mapq(&record, None)); // Do not get filtered because filter is disabled + + record + .mapping_quality_mut() + .replace(MappingQuality::new(10).unwrap()); + assert!(!filter_by_mapq( + &record, + Some(MappingQuality::new(0).unwrap()) + )); // Do not get filtered because MAPQ is present + assert!(!filter_by_mapq( + &record, + Some(MappingQuality::new(1).unwrap()) + )); // Do not get filtered because MAPQ is greater than 1 + assert!(filter_by_mapq( + &record, + Some(MappingQuality::new(11).unwrap()) + )); // Do get filtered because MAPQ is less than 11 + Ok(()) + } #[test] pub fn it_correctly_returns_zero_edits_when_sequences_are_identical() -> anyhow::Result<()> { diff --git a/src/utils/args.rs b/src/utils/args.rs index e576400..ac6f309 100644 --- a/src/utils/args.rs +++ b/src/utils/args.rs @@ -1,6 +1,7 @@ //! Utilities related to the parsing of arguments. use std::fmt::Display; +use std::num::NonZeroUsize; use noodles::bgzf::writer::CompressionLevel; use tracing::debug; @@ -11,17 +12,76 @@ use tracing::debug; /// Utility enum to designate whether we are reviewing all records in the file /// or just some of them. +#[derive(Clone, Debug)] pub enum NumberOfRecords { /// Designates that we should review _all_ of the records in the file. All, /// Designates that we should review _some_ of the records in the file. The /// exact count of records is stored in the `usize`. - Some(usize), + Some(NonZeroUsize), } -impl From<Option<usize>> for NumberOfRecords { - fn from(num_records: Option<usize>) -> Self { +impl std::default::Default for NumberOfRecords { + fn default() -> Self { + Self::All + } +} + +impl std::fmt::Display for NumberOfRecords { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NumberOfRecords::All => write!(f, "all"), + NumberOfRecords::Some(value) => write!(f, "{value}"), + } + } +} + +/// An error type for parsing the number of records. +#[derive(Debug)] +pub enum NumberOfRecordsError { + /// The number of records is invalid.
+ Invalid(String), +} + +impl std::fmt::Display for NumberOfRecordsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NumberOfRecordsError::Invalid(value) => write!(f, "invalid number of records: {value}"), + } + } +} + +impl std::error::Error for NumberOfRecordsError {} + +impl std::str::FromStr for NumberOfRecords { + type Err = NumberOfRecordsError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.to_lowercase().as_str() { + "all" => Ok(NumberOfRecords::All), + _ => s + .parse::<usize>() + .map_err(|_| { + NumberOfRecordsError::Invalid(String::from( + "must be a positive, non-zero integer or 'all'", + )) + }) + .and_then(|num| { + NonZeroUsize::new(num) + .ok_or_else(|| { + NumberOfRecordsError::Invalid(String::from( + "integers must be positive and non-zero", + )) + }) + .map(NumberOfRecords::Some) + }), + } + } +} + +impl From<Option<NonZeroUsize>> for NumberOfRecords { + fn from(num_records: Option<NonZeroUsize>) -> Self { match num_records { Some(n) => { debug!("Reading a maximum of {} records.", n); @@ -73,3 +133,20 @@ impl From for CompressionLevel { } } } + +//=============// +// Arg Parsers // +//=============// + +/// Utility method to parse command line floats and ensure they are +/// within the range [MIN, MAX]. +pub fn arg_in_range(arg: f64, range: std::ops::RangeInclusive<f64>) -> anyhow::Result<f64> { + match range.contains(&arg) { + true => Ok(arg), + false => anyhow::bail!( + "Value must be between {} and {}", + range.start(), + range.end() + ), + } +} diff --git a/src/utils/display.rs b/src/utils/display.rs index 2a58111..1ffbbca 100644 --- a/src/utils/display.rs +++ b/src/utils/display.rs @@ -24,29 +24,46 @@ impl fmt::Display for PercentageFormat { } /// Utility struct used to uniformly count and report the number of records processed. -#[derive(Default)] -pub struct RecordCounter(usize); +pub struct RecordCounter { + /// The number of records processed. + count: usize, + + /// How often, in number of records, to log progress. + log_every: usize, +} + +impl Default for RecordCounter { + fn default() -> Self { + RecordCounter { + count: 0, + log_every: 1_000_000, + } + } +} impl RecordCounter { /// Creates a new `RecordCounter`. - pub fn new() -> Self { - Self::default() + pub fn new(log_every: Option<usize>) -> Self { + RecordCounter { + count: 0, + log_every: log_every.unwrap_or(1_000_000), + } } /// Gets the current number of records counted via a copy. pub fn get(&self) -> usize { - self.0 + self.count } /// Increments the counter and reports the number of records processed (if /// appropriate). pub fn inc(&mut self) { - self.0 += 1; + self.count += 1; - if self.0 % 1_000_000 == 0 { + if self.count % self.log_every == 0 { info!( " [*] Processed {} records.", - self.0.to_formatted_string(&Locale::en), + self.count.to_formatted_string(&Locale::en), ); } } @@ -57,7 +74,7 @@ /// (if it exists, otherwise it loops forever). pub fn time_to_break(&self, limit: &NumberOfRecords) -> bool { match limit { - NumberOfRecords::Some(v) => self.0 >= *v, + NumberOfRecords::Some(v) => self.count >= usize::from(*v), NumberOfRecords::All => false, } } diff --git a/src/utils/read_groups.rs b/src/utils/read_groups.rs new file mode 100644 index 0000000..86caf5d --- /dev/null +++ b/src/utils/read_groups.rs @@ -0,0 +1,75 @@ +//! This module contains functions to validate the read group information in the header and the records.
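+//!
+//! For example (hypothetical record): a record carrying the tag `RG:Z:rg1`
+//! resolves to a shared `Arc<String>` for "rg1", while a record with no read
+//! group tag resolves to the shared `UNKNOWN_READ_GROUP` pointer.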
+ +use lazy_static::lazy_static; +use noodles::sam::alignment::Record; +use noodles::sam::header; +use noodles::sam::record::data::field::Tag; +use std::collections::HashSet; +use std::sync::Arc; +use tracing::warn; + +/// Type alias for a read group pointer. +pub type ReadGroupPtr = Arc<String>; + +lazy_static! { + /// String used to represent an unknown read group. Wrapped in an Arc to prevent redundant memory usage. + pub static ref UNKNOWN_READ_GROUP: ReadGroupPtr = Arc::new(String::from("unknown_read_group")); +} + +/// Returns the read group tag from the record. +/// If the read group is not found in the record, the read group is set to "unknown_read_group". +/// TODO: Revisit this logic +pub fn get_read_group( + record: &Record, + found_rgs: Option<&mut HashSet<ReadGroupPtr>>, +) -> ReadGroupPtr { + match (record.data().get(Tag::ReadGroup), found_rgs) { + (Some(rg), Some(read_groups)) => { + let rg = rg.to_string(); + if !read_groups.contains(&rg) { + read_groups.insert(Arc::new(rg.clone())); + } + Arc::clone(read_groups.get(&rg).unwrap()) + } + (Some(rg), None) => Arc::new(rg.to_string()), + (None, _) => Arc::clone(&UNKNOWN_READ_GROUP), + } +} + +/// Compares the read group tags found in the records +/// and the read groups found in the header. +/// Returns a vector of read group names that were found in the header +/// but not in the records. +pub fn validate_read_group_info( + found_rgs: &HashSet<ReadGroupPtr>, + header: &header::Header, +) -> Vec<String> { + let mut rgs_in_header_not_records = Vec::new(); + let mut rgs_in_records_not_header = Vec::new(); + + for (rg_id, _) in header.read_groups() { + if !found_rgs.contains(rg_id) { + rgs_in_header_not_records.push(rg_id.to_string()); + } + } + if !rgs_in_header_not_records.is_empty() { + warn!( + "The following read groups were not found in the file: {:?}", + rgs_in_header_not_records + ); + } + + for rg_id in found_rgs { + if !header.read_groups().contains_key(rg_id.as_str()) { + rgs_in_records_not_header.push(rg_id.to_string()); + } + } + if !rgs_in_records_not_header.is_empty() { + warn!( + "The following read groups were not found in the header: {:?}", + rgs_in_records_not_header + ); + } + + rgs_in_header_not_records +}
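As a rough sketch of how the new `NumberOfRecords` and `RecordCounter` pieces fit together (the `records` iterator and `process` function below are hypothetical stand-ins, not part of this change):

// A subcommand parses a limit ("all" or a positive, non-zero integer), then
// counts records until the limit is reached.
let limit: NumberOfRecords = "1000".parse()?; // equivalently, "all".parse()?
let mut counter = RecordCounter::default(); // logs every 1,000,000 records
for record in records {
    process(&record)?;
    counter.inc();
    if counter.time_to_break(&limit) {
        break;
    }
}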