diff --git a/packages/nextclade/src/align/seed_match.rs b/packages/nextclade/src/align/seed_match.rs index 0888b066b..bb50363ea 100644 --- a/packages/nextclade/src/align/seed_match.rs +++ b/packages/nextclade/src/align/seed_match.rs @@ -16,7 +16,7 @@ use std::cmp::{max, min}; use std::collections::{BTreeMap, VecDeque}; /// Copied from https://stackoverflow.com/a/75084739/7483211 -struct SkipEvery { +pub struct SkipEvery { inner: I, every: usize, index: usize, @@ -44,7 +44,7 @@ impl Iterator for SkipEvery { } } -trait IteratorSkipEveryExt: Iterator + Sized { +pub trait IteratorSkipEveryExt: Iterator + Sized { fn skip_every(self, every: usize) -> SkipEvery { SkipEvery::new(self, every) } diff --git a/packages/nextclade/src/sort/minimizer_index.rs b/packages/nextclade/src/sort/minimizer_index.rs index 4bdcefd7d..1c40df7dc 100644 --- a/packages/nextclade/src/sort/minimizer_index.rs +++ b/packages/nextclade/src/sort/minimizer_index.rs @@ -14,7 +14,7 @@ pub const MINIMIZER_INDEX_SCHEMA_VERSION_FROM: &str = "3.0.0"; pub const MINIMIZER_INDEX_SCHEMA_VERSION_TO: &str = "3.0.0"; pub const MINIMIZER_INDEX_ALGO_VERSION: &str = "1"; -pub type MinimizerMap = BTreeMap>; +pub type MinimizerMap = BTreeMap>; /// Contains external configuration and data specific for a particular pathogen #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] @@ -56,7 +56,7 @@ pub fn serde_deserialize_minimizers<'de, D: Deserializer<'de>>(deserializer: D) let res = map .into_iter() - .map(|(k, v)| Ok((u64::from_str(&k)?, v))) + .map(|(k, v)| Ok((u32::from_str(&k)?, v))) .collect::>() .map_err(serde::de::Error::custom)?; diff --git a/packages/nextclade/src/sort/minimizer_search.rs b/packages/nextclade/src/sort/minimizer_search.rs index baa3dc865..2cd2252a2 100644 --- a/packages/nextclade/src/sort/minimizer_search.rs +++ b/packages/nextclade/src/sort/minimizer_search.rs @@ -1,3 +1,4 @@ +use crate::align::seed_match::IteratorSkipEveryExt; use crate::io::fasta::FastaRecord; use 
crate::sort::minimizer_index::{MinimizerIndexJson, MinimizerIndexParams}; use crate::sort::params::NextcladeSeqSortParams; @@ -10,7 +11,7 @@ use log::debug; use ordered_float::OrderedFloat; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap, BTreeSet, HashSet}; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] @@ -44,17 +45,7 @@ pub fn run_minimizer_search( ) -> Result { let n_refs = index.references.len(); - let minimizers = get_ref_search_minimizers(fasta_record, &index.params); - let mut hit_counts = vec![0; n_refs]; - for m in minimizers { - if let Some(mz) = index.minimizers.get(&m) { - for (ri, hit_count) in hit_counts.iter_mut().enumerate() { - if mz.contains(&ri) { - *hit_count += 1; - } - } - } - } + let hit_counts = calculate_minimizer_hits(fasta_record, index, n_refs); // we expect hits to be proportional to the length of the sequence and the number of minimizers per reference let mut scores: Vec = vec![0.0; hit_counts.len()]; @@ -231,68 +222,94 @@ pub struct DatasetSuggestionStats { pub qry_indices: Vec, } -const fn invertible_hash(x: u64) -> u64 { - let m: u64 = (1 << 32) - 1; - let mut x: u64 = (!x).wrapping_add(x << 21) & m; +const fn invertible_hash(x: u32) -> u32 { + let mut x = (!x).wrapping_add(x << 21); x = x ^ (x >> 24); - x = (x + (x << 3) + (x << 8)) & m; + x = x.wrapping_add(x << 3).wrapping_add(x << 8); x = x ^ (x >> 14); - x = (x + (x << 2) + (x << 4)) & m; + x = x.wrapping_add(x << 2).wrapping_add(x << 4); x = x ^ (x >> 28); - x = (x + (x << 31)) & m; + x = x.wrapping_add(x << 31); x } -fn get_hash(kmer: &[u8], params: &MinimizerIndexParams) -> u64 { - let cutoff = params.cutoff as u64; - - let mut x = 0; - let mut j = 0; - - for (i, nuc) in kmer.iter().enumerate() { - let nuc = *nuc as char; - - if i % 3 == 2 { - continue; // skip every third nucleotide to pick up conserved patterns - } - - if 
!"ACGT".contains(nuc) { - return cutoff + 1; // break out of loop, return hash above cutoff - } - - // A=11=3, C=10=2, G=00=0, T=01=1 - if "AC".contains(nuc) { - x += 1 << j; - } - - if "AT".contains(nuc) { - x += 1 << (j + 1); +const INVALID_NUCLEOTIDE_VALUE: u8 = 4; // Sentinel value for invalid nucleotides + +const NUCLEOTIDE_LOOKUP: [u8; 256] = { + let mut table = [INVALID_NUCLEOTIDE_VALUE; 256]; // Use sentinel value 4 for invalid nucleotides + table[b'A' as usize] = 0b11; // A=11=3 + table[b'a' as usize] = 0b11; // a=11=3 + table[b'T' as usize] = 0b10; // T=10=2 + table[b't' as usize] = 0b10; // t=10=2 + table[b'G' as usize] = 0b00; // G=00=0 + table[b'g' as usize] = 0b00; // g=00=0 + table[b'C' as usize] = 0b01; // C=01=1 + table[b'c' as usize] = 0b01; // c=01=1 + table +}; + +// Expects bit-encoded kmer where each nucleotide is represented by 2 bits +// A=11, T=10, C=01, G=00 and invalid nucleotides are represented by INVALID_NUCLEOTIDE_VALUE +fn get_hash(kmer: &[u8], params: &MinimizerIndexParams) -> u32 { + let cutoff = params.cutoff as u32; + + let mut x: u32 = 0; + let mut j: u8 = 0; + + // Skip every third nucleotide to pick up conserved patterns + for &bits in kmer.iter().skip_every(3) { + if bits == INVALID_NUCLEOTIDE_VALUE { + return cutoff + 1; // invalid nucleotide } + x |= (bits as u32) << j; j += 2; } invertible_hash(x) } -pub fn get_ref_search_minimizers(seq: &FastaRecord, params: &MinimizerIndexParams) -> Vec { - let k = params.k as usize; - let cutoff = params.cutoff as u64; - - let seq_str = preprocess_seq(&seq.seq); - let n = seq_str.len().saturating_sub(k); - let mut minimizers = Vec::with_capacity(n); - for i in 0..n { - let kmer = &seq_str.as_bytes()[i..i + k]; - let mhash = get_hash(kmer, params); - // accept only hashes below cutoff --> reduces the size of the index and the number of lookups - if mhash < cutoff { - minimizers.push(mhash); - } - } - minimizers.into_iter().unique().collect_vec() +pub fn calculate_minimizer_hits( + 
fasta_record: &FastaRecord, + index: &MinimizerIndexJson, + n_refs: usize, +) -> Vec { + let params = &index.params; + let k = params.k as usize; + let cutoff = params.cutoff as u32; + + let expected_n_minimizers = fasta_record.seq.len() as f64 * cutoff as f64 / (1_u64 << 32) as f64; + let mut seen: HashSet<_> = HashSet::with_capacity(expected_n_minimizers as usize); + + let seq_str = preprocess_seq(&fasta_record.seq); + + seq_str.windows(k) + .map(|kmer| get_hash(kmer, params)) + .filter(|&mhash| mhash < cutoff && seen.insert(mhash)) // Faster than .unique() + .fold(vec![0; n_refs], |mut acc, m| { + if let Some(locations) = index.minimizers.get(&m) { + for &ref_idx in locations { + if let Some(count) = acc.get_mut(ref_idx) { + *count += 1; + } + } + } + acc + }) } -fn preprocess_seq(seq: impl AsRef) -> String { - seq.as_ref().to_uppercase().replace('-', "") +// Encode the sequence one byte per nucleotide, dropping gap characters ('-'), +// where each nucleotide is represented by 2 bits: +// A=11, T=10, C=01, G=00 +// Invalid nucleotides are represented by INVALID_NUCLEOTIDE_VALUE +fn preprocess_seq(seq: &str) -> Vec{ + let mut result = Vec::with_capacity(seq.len()); + result.extend(seq.bytes().filter_map(|b| { + if b == b'-' { + None + } else { + Some(NUCLEOTIDE_LOOKUP[b as usize]) + } + })); + result }