Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/nextclade/src/align/seed_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::cmp::{max, min};
use std::collections::{BTreeMap, VecDeque};

/// Copied from https://stackoverflow.com/a/75084739/7483211
struct SkipEvery<I> {
pub struct SkipEvery<I> {
inner: I,
every: usize,
index: usize,
Expand Down Expand Up @@ -44,7 +44,7 @@ impl<I: Iterator> Iterator for SkipEvery<I> {
}
}

trait IteratorSkipEveryExt: Iterator + Sized {
pub trait IteratorSkipEveryExt: Iterator + Sized {
fn skip_every(self, every: usize) -> SkipEvery<Self> {
SkipEvery::new(self, every)
}
Expand Down
4 changes: 2 additions & 2 deletions packages/nextclade/src/sort/minimizer_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub const MINIMIZER_INDEX_SCHEMA_VERSION_FROM: &str = "3.0.0";
pub const MINIMIZER_INDEX_SCHEMA_VERSION_TO: &str = "3.0.0";
pub const MINIMIZER_INDEX_ALGO_VERSION: &str = "1";

pub type MinimizerMap = BTreeMap<u64, Vec<usize>>;
pub type MinimizerMap = BTreeMap<u32, Vec<usize>>;

/// Contains external configuration and data specific for a particular pathogen
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
Expand Down Expand Up @@ -56,7 +56,7 @@ pub fn serde_deserialize_minimizers<'de, D: Deserializer<'de>>(deserializer: D)

let res = map
.into_iter()
.map(|(k, v)| Ok((u64::from_str(&k)?, v)))
.map(|(k, v)| Ok((u32::from_str(&k)?, v)))
.collect::<Result<MinimizerMap, Report>>()
.map_err(serde::de::Error::custom)?;

Expand Down
137 changes: 77 additions & 60 deletions packages/nextclade/src/sort/minimizer_search.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::align::seed_match::IteratorSkipEveryExt;
use crate::io::fasta::FastaRecord;
use crate::sort::minimizer_index::{MinimizerIndexJson, MinimizerIndexParams};
use crate::sort::params::NextcladeSeqSortParams;
Expand All @@ -10,7 +11,7 @@ use log::debug;
use ordered_float::OrderedFloat;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, BTreeSet};
use std::collections::{BTreeMap, BTreeSet, HashSet};

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -44,17 +45,7 @@ pub fn run_minimizer_search(
) -> Result<MinimizerSearchResult, Report> {
let n_refs = index.references.len();

let minimizers = get_ref_search_minimizers(fasta_record, &index.params);
let mut hit_counts = vec![0; n_refs];
for m in minimizers {
if let Some(mz) = index.minimizers.get(&m) {
for (ri, hit_count) in hit_counts.iter_mut().enumerate() {
if mz.contains(&ri) {
*hit_count += 1;
}
}
}
}
let hit_counts = calculate_minimizer_hits(fasta_record, index, n_refs);

// we expect hits to be proportional to the length of the sequence and the number of minimizers per reference
let mut scores: Vec<f64> = vec![0.0; hit_counts.len()];
Expand Down Expand Up @@ -231,68 +222,94 @@ pub struct DatasetSuggestionStats {
pub qry_indices: Vec<usize>,
}

const fn invertible_hash(x: u64) -> u64 {
let m: u64 = (1 << 32) - 1;
let mut x: u64 = (!x).wrapping_add(x << 21) & m;
const fn invertible_hash(x: u32) -> u32 {
let mut x = (!x).wrapping_add(x << 21);
x = x ^ (x >> 24);
x = (x + (x << 3) + (x << 8)) & m;
x = x.wrapping_add(x << 3).wrapping_add(x << 8);
x = x ^ (x >> 14);
x = (x + (x << 2) + (x << 4)) & m;
x = x.wrapping_add(x << 2).wrapping_add(x << 4);
x = x ^ (x >> 28);
x = (x + (x << 31)) & m;
x = x.wrapping_add(x << 31);
x
}

fn get_hash(kmer: &[u8], params: &MinimizerIndexParams) -> u64 {
let cutoff = params.cutoff as u64;

let mut x = 0;
let mut j = 0;

for (i, nuc) in kmer.iter().enumerate() {
let nuc = *nuc as char;

if i % 3 == 2 {
continue; // skip every third nucleotide to pick up conserved patterns
}

if !"ACGT".contains(nuc) {
return cutoff + 1; // break out of loop, return hash above cutoff
}

// A=11=3, C=10=2, G=00=0, T=01=1
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment was actually wrong - T and C should have been interchanged.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's why the initial preview finds no matches (I used the comment rather than the code to implement the lookup table)

if "AC".contains(nuc) {
x += 1 << j;
}

if "AT".contains(nuc) {
x += 1 << (j + 1);
const INVALID_NUCLEOTIDE_VALUE: u8 = 4; // Sentinel value for invalid nucleotides

const NUCLEOTIDE_LOOKUP: [u8; 256] = {
let mut table = [INVALID_NUCLEOTIDE_VALUE; 256]; // Use sentinel value 4 for invalid nucleotides
table[b'A' as usize] = 0b11; // A=11=3
table[b'a' as usize] = 0b11; // a=11=3
table[b'T' as usize] = 0b10; // T=10=2
table[b't' as usize] = 0b10; // t=10=2
table[b'G' as usize] = 0b00; // G=00=0
table[b'g' as usize] = 0b00; // g=00=0
table[b'C' as usize] = 0b01; // C=01=1
table[b'c' as usize] = 0b01; // c=01=1
table
};

// Expects bit-encoded kmer where each nucleotide is represented by 2 bits
// A=11, C=10, G=00, T=01 and invalid nucleotides are represented by INVALID_NUCLEOTIDE_VALUE
fn get_hash(kmer: &[u8], params: &MinimizerIndexParams) -> u32 {
let cutoff = params.cutoff as u32;

let mut x: u32 = 0;
let mut j: u8 = 0;

// Skip every third nucleotide to pick up conserved patterns
for &bits in kmer.iter().skip_every(3) {
if bits == INVALID_NUCLEOTIDE_VALUE {
return cutoff + 1; // invalid nucleotide
}

x |= (bits as u32) << j;
j += 2;
}

invertible_hash(x)
}

pub fn get_ref_search_minimizers(seq: &FastaRecord, params: &MinimizerIndexParams) -> Vec<u64> {
let k = params.k as usize;
let cutoff = params.cutoff as u64;

let seq_str = preprocess_seq(&seq.seq);
let n = seq_str.len().saturating_sub(k);
let mut minimizers = Vec::with_capacity(n);
for i in 0..n {
let kmer = &seq_str.as_bytes()[i..i + k];
let mhash = get_hash(kmer, params);
// accept only hashes below cutoff --> reduces the size of the index and the number of lookups
if mhash < cutoff {
minimizers.push(mhash);
}
}
minimizers.into_iter().unique().collect_vec()
pub fn calculate_minimizer_hits(
fasta_record: &FastaRecord,
index: &MinimizerIndexJson,
n_refs: usize,
) -> Vec<u64> {
let params = &index.params;
let k = params.k as usize;
let cutoff = params.cutoff as u32;

let expected_n_minimizers = fasta_record.seq.len() as f64 * cutoff as f64 / (1_u64 << 32) as f64;
let mut seen: HashSet<_> = HashSet::with_capacity(expected_n_minimizers as usize);

let seq_str = preprocess_seq(&fasta_record.seq);

seq_str.windows(k)
.map(|kmer| get_hash(kmer, params))
.filter(|&mhash| mhash < cutoff && seen.insert(mhash)) // Faster than .unique()
.fold(vec![0; n_refs], |mut acc, m| {
if let Some(locations) = index.minimizers.get(&m) {
for &ref_idx in locations {
if let Some(count) = acc.get_mut(ref_idx) {
*count += 1;
}
}
}
acc
})
}

fn preprocess_seq(seq: impl AsRef<str>) -> String {
seq.as_ref().to_uppercase().replace('-', "")
// Create a bit-packed representation of the kmer
// where each nucleotide is represented by 2 bits:
// A=11, C=10, G=00, T=01
// Invalid nucleotides are represented by INVALID_NUCLEOTIDE_VALUE
fn preprocess_seq(seq: &str) -> Vec<u8>{
let mut result = Vec::with_capacity(seq.len());
result.extend(seq.bytes().filter_map(|b| {
if b == b'-' {
None
} else {
Some(NUCLEOTIDE_LOOKUP[b as usize])
}
}));
result
}