use std::mem;
use itertools::Itertools;
use std::ops::Deref;
use std::marker::PhantomData;
use std::hash::Hash;
use concurrent_hashmap::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use log::debug;
use crate::Dir;
use crate::Kmer;
use crate::Exts;
use crate::Vmer;
use boomphf::hashmap::BoomHashMap2;
use std::fmt::Debug;
/// Map a kmer to one of 256 buckets using its first four 2-bit bases
/// (base 0 is the most significant pair of bits).
fn bucket<K: Kmer>(kmer: K) -> usize {
    (0..4).fold(0usize, |acc, pos| (acc << 2) | kmer.get(pos) as usize)
}
/// Policy trait: given every observation of a single kmer, decide whether the
/// kmer passes the filter and produce the summary value (`DO`) to store for it.
/// `DI` is the per-observation data type attached to each input sequence.
pub trait KmerSummarizer<DI, DO> {
/// `items` yields all observations `(kmer, exts, data)` of one kmer.
/// Returns (kmer accepted?, combined extensions, summarized data).
fn summarize<K, F: Iterator<Item = (K, Exts, DI)>>(&self, items: F) -> (bool, Exts, DO);
}
/// Summarizer that counts observations of each kmer and accepts kmers seen
/// at least `min_kmer_obs` times; the stored summary is the count (`u16`).
pub struct CountFilter {
// Minimum number of observations required for a kmer to be kept.
min_kmer_obs: usize,
}
impl CountFilter {
    /// Build a `CountFilter` that keeps kmers observed at least
    /// `min_kmer_obs` times.
    pub fn new(min_kmer_obs: usize) -> CountFilter {
        CountFilter { min_kmer_obs }
    }
}
impl<D> KmerSummarizer<D, u16> for CountFilter {
    /// Tally the observations of one kmer and union their extensions.
    /// The count saturates at `u16::MAX` rather than overflowing; the
    /// per-observation data is ignored.
    fn summarize<K, F: Iterator<Item = (K, Exts, D)>>(&self, items: F) -> (bool, Exts, u16) {
        let (count, all_exts) = items.fold(
            (0u16, Exts::empty()),
            |(n, acc_exts), (_, exts, _)| (n.saturating_add(1), acc_exts.add(exts)),
        );
        (count as usize >= self.min_kmer_obs, all_exts, count)
    }
}
/// Summarizer that accepts kmers seen at least `min_kmer_obs` times and
/// stores, for each accepted kmer, the sorted deduplicated set of data
/// values (`Vec<D>`) it was observed with.
pub struct CountFilterSet<D> {
// Minimum number of observations required for a kmer to be kept.
min_kmer_obs: usize,
// Marks the data type `D` this summarizer operates on; no `D` is stored.
phantom: PhantomData<D>,
}
impl<D> CountFilterSet<D> {
    /// Build a `CountFilterSet` that keeps kmers observed at least
    /// `min_kmer_obs` times.
    pub fn new(min_kmer_obs: usize) -> CountFilterSet<D> {
        CountFilterSet {
            min_kmer_obs,
            phantom: PhantomData,
        }
    }
}
impl<D: Ord> KmerSummarizer<D, Vec<D>> for CountFilterSet<D> {
    /// Gather every data value observed with one kmer, union the extensions,
    /// and return the sorted, deduplicated value set. Acceptance is based on
    /// the raw observation count (before deduplication).
    fn summarize<K, F: Iterator<Item = (K, Exts, D)>>(&self, items: F) -> (bool, Exts, Vec<D>) {
        let mut merged_exts = Exts::empty();
        let mut observations: Vec<D> = Vec::new();
        for (_, exts, data) in items {
            merged_exts = merged_exts.add(exts);
            observations.push(data);
        }
        // Count before dedup: acceptance reflects total observations.
        let nobs = observations.len();
        observations.sort();
        observations.dedup();
        (nobs >= self.min_kmer_obs, merged_exts, observations)
    }
}
/// Identifier of an interned equivalence class (index into the class table).
pub type EqClassIdType = u32 ;
/// Summarizer that interns each kmer's deduplicated observation set into a
/// shared table of equivalence classes and stores only the class id per kmer.
pub struct CountFilterEqClass<D: Eq + Hash + Send + Sync + Debug + Clone> {
// Minimum number of observations required for a kmer to be kept.
min_kmer_obs: usize,
// Concurrent map from observation set -> assigned class id.
eq_classes: ConcHashMap<Vec<D>, EqClassIdType>,
// Counter handing out the next class id; also the table size.
num_eq_classes: AtomicUsize,
}
impl<D: Eq + Hash + Send + Sync + Debug + Clone> CountFilterEqClass<D> {
/// Build a summarizer with an empty equivalence-class table.
pub fn new(min_kmer_obs: usize) -> CountFilterEqClass<D> {
CountFilterEqClass {
min_kmer_obs: min_kmer_obs,
eq_classes: ConcHashMap::<Vec<D>, EqClassIdType>::new(),
num_eq_classes: AtomicUsize::new(0),
}
}
/// Return the interned equivalence classes as a vector indexed by class id.
pub fn get_eq_classes(&self) -> Vec<Vec<D>>{
let mut eq_class_vec = Vec::new();
// One slot per class id handed out so far.
eq_class_vec.resize(self.get_number_of_eq_classes(), Vec::new());
// NOTE(review): if `summarize` runs concurrently with this call, the map may
// contain an id >= the counter snapshot read above, which would panic on the
// index below — confirm callers only invoke this after filtering completes.
for (key, value) in self.eq_classes.iter() {
eq_class_vec[*value as usize] = key.clone();
}
eq_class_vec
}
/// Number of distinct equivalence classes assigned so far.
pub fn get_number_of_eq_classes(&self) -> usize{
self.num_eq_classes.load(Ordering::SeqCst)
}
/// Atomically reserve the next class id; returns the pre-increment count.
pub fn fetch_add(&self) -> usize {
self.num_eq_classes.fetch_add(1, Ordering::SeqCst)
}
}
impl<D: Eq + Ord + Hash + Send + Sync + Debug + Clone> KmerSummarizer<D, EqClassIdType> for CountFilterEqClass<D> {
/// Collect the set of data values observed with one kmer, canonicalize it
/// (sort + dedup), and map it to an equivalence-class id, interning the set
/// on first sight. Acceptance is based on the raw observation count.
fn summarize<K, F: Iterator<Item = (K, Exts, D)>>(&self, items: F) -> (bool, Exts, EqClassIdType) {
let mut all_exts = Exts::empty();
let mut out_data = Vec::new();
let mut nobs = 0;
for (_, exts, d) in items {
out_data.push(d);
all_exts = all_exts.add(exts);
nobs += 1;
}
// Canonical form of the observation set so equal sets hash/compare equal.
out_data.sort(); out_data.dedup();
// NOTE(review): find-then-insert is not atomic. Two threads that both miss
// on the same set may each reserve an id via fetch_add; one insert then
// overwrites the other, leaking an id and leaving kmers with different ids
// for the same set. Confirm whether summarize is called concurrently.
let eq_id: EqClassIdType = match self.eq_classes.find(&out_data) {
Some(val) => *val.get(),
None => {
let val = self.fetch_add() as EqClassIdType;
self.eq_classes.insert(out_data, val);
val
},
};
(nobs as usize >= self.min_kmer_obs, all_exts, eq_id)
}
}
/// Extract all kmers from `seqs`, group duplicate observations of each kmer,
/// summarize each group with `summarizer`, and return the accepted kmers with
/// their combined extensions and summary data in a `BoomHashMap2`, plus (when
/// `report_all_kmers` is set) the list of every distinct kmer seen.
///
/// * `stranded` — when false, kmers are canonicalized to the minimum of the
///   kmer and its reverse complement (extensions flipped to match).
/// * `memory_size` — kmer-buffer budget; the 10^9 multiplier implies units of
///   gigabytes. The work is split into multiple passes over `seqs` so each
///   pass's buffered kmers fit in the budget.
///
/// Fix vs. previous revision: `memory_size == 0` used to panic with a
/// divide-by-zero (`max_mem` became 0); the budget is now clamped to >= 1 byte.
#[inline(never)]
pub fn filter_kmers<K: Kmer, V: Vmer, D1: Clone, DS, S: KmerSummarizer<D1, DS>>(
    seqs: &[(V, Exts, D1)],
    summarizer: &dyn Deref<Target=S>,
    stranded: bool,
    report_all_kmers: bool,
    memory_size: usize,
) -> ( BoomHashMap2<K, Exts, DS>, Vec<K> )
where DS: Debug{
    let rc_norm = !stranded;

    let mut all_kmers = Vec::new();
    let mut valid_kmers = Vec::new();
    let mut valid_exts = Vec::new();
    let mut valid_data = Vec::new();

    // Total kmer count across inputs; sequences shorter than K contribute
    // zero (saturating_sub guards the underflow).
    let input_kmers: usize = seqs.iter()
        .map(|&(ref vmer, _, _)| vmer.len().saturating_sub(K::k() - 1))
        .sum();

    // Estimate memory to buffer every kmer at once and derive how many passes
    // over the 256 buckets are needed to stay within budget. `.max(1)` avoids
    // a divide-by-zero panic when memory_size == 0.
    let kmer_mem = input_kmers * mem::size_of::<(K, D1)>();
    let max_mem = (memory_size * (10 as usize).pow(9)).max(1);
    let slices = kmer_mem / max_mem + 1;
    let sz = 256 / slices + 1;

    // Partition bucket ids [0, 256) into contiguous ranges, one per pass.
    let mut bucket_ranges = Vec::new();
    let mut start = 0;
    while start < 256 {
        bucket_ranges.push(start..start + sz);
        start += sz;
    }
    // The ranges must jointly cover all 256 buckets.
    assert!(bucket_ranges[bucket_ranges.len() - 1].end >= 256);
    let n_buckets = bucket_ranges.len();

    if bucket_ranges.len() > 1 {
        debug!(
            "{} sequences, {} kmers, {} passes",
            seqs.len(),
            input_kmers,
            bucket_ranges.len()
        );
    }

    for (i, bucket_range) in bucket_ranges.into_iter().enumerate() {
        debug!("Processing bucket {} of {}", i, n_buckets);

        // One vector per bucket id; only ids inside this pass's range fill up.
        let mut kmer_buckets: Vec<Vec<(K, Exts, D1)>> =
            (0..256).map(|_| Vec::new()).collect();

        for &(ref seq, seq_exts, ref d) in seqs {
            for (kmer, exts) in seq.iter_kmer_exts::<K>(seq_exts) {
                // Unstranded data: canonicalize to min(kmer, rc(kmer)) and
                // flip the extensions when the rc orientation was chosen.
                let (min_kmer, flip_exts) = if rc_norm {
                    let (min_kmer, flip) = kmer.min_rc_flip();
                    let flip_exts = if flip { exts.rc() } else { exts };
                    (min_kmer, flip_exts)
                } else {
                    (kmer, exts)
                };
                let bucket = bucket(min_kmer);

                if bucket >= bucket_range.start && bucket < bucket_range.end {
                    kmer_buckets[bucket].push((min_kmer, flip_exts, d.clone()));
                }
            }
        }

        for mut kmer_vec in kmer_buckets {
            // Sort so identical kmers are adjacent, then summarize each run.
            kmer_vec.sort_by_key(|elt| elt.0);

            for (kmer, kmer_obs_iter) in &kmer_vec.into_iter().group_by(|elt| elt.0) {
                let (is_valid, exts, summary_data) = summarizer.summarize(kmer_obs_iter);
                if report_all_kmers { all_kmers.push(kmer); }
                if is_valid {
                    valid_kmers.push(kmer);
                    valid_exts.push(exts);
                    valid_data.push(summary_data);
                }
            }
        }
    }

    debug!(
        "Unique kmers: {}, All kmers (if returned): {}",
        valid_kmers.len(),
        all_kmers.len(),
    );
    (BoomHashMap2::new(valid_kmers, valid_exts, valid_data), all_kmers)
}
/// Prune extensions that point at "censored" kmers — kmers that appear in the
/// global `all_kmers` list but were filtered out of this shard's `valid_kmers`.
/// Extensions into kmers absent from both lists are kept: those kmers live in
/// a different shard. Both slices are searched with `binary_search`, so they
/// are assumed sorted; `valid_kmers` is updated in place.
pub fn remove_censored_exts_sharded<K: Kmer, D>(
    stranded: bool,
    valid_kmers: &mut Vec<(K, (Exts, D))>,
    all_kmers: &Vec<K>,
) {
    for idx in 0..valid_kmers.len() {
        let kmer = valid_kmers[idx].0;
        let current_exts = (valid_kmers[idx].1).0;
        let mut retained = Exts::empty();

        for &dir in [Dir::Left, Dir::Right].iter() {
            for base in 0..4 {
                if !current_exts.has_ext(dir, base) {
                    continue;
                }
                // Neighbor kmer reached via this extension, canonicalized
                // for unstranded data.
                let raw = kmer.extend(base, dir);
                let neighbor = if stranded { raw } else { raw.min_rc() };

                // Valid here -> not censored. Otherwise censored only if it
                // was observed globally (i.e. it got filtered out).
                let censored = if valid_kmers
                    .binary_search_by_key(&neighbor, |entry| entry.0)
                    .is_ok()
                {
                    false
                } else {
                    all_kmers.binary_search(&neighbor).is_ok()
                };

                if !censored {
                    retained = retained.set(dir, base);
                }
            }
        }

        (valid_kmers[idx].1).0 = retained;
    }
}
/// Prune extensions that point at kmers not present in `valid_kmers`: only
/// extensions whose neighbor kmer survives filtering are kept. Lookups use
/// `binary_search_by_key`, so `valid_kmers` is assumed sorted by kmer; the
/// vector is updated in place.
pub fn remove_censored_exts<K: Kmer, D>(stranded: bool, valid_kmers: &mut Vec<(K, (Exts, D))>) {
    for idx in 0..valid_kmers.len() {
        let kmer = valid_kmers[idx].0;
        let current_exts = (valid_kmers[idx].1).0;
        let mut retained = Exts::empty();

        for &dir in [Dir::Left, Dir::Right].iter() {
            for base in 0..4 {
                if !current_exts.has_ext(dir, base) {
                    continue;
                }
                // Neighbor kmer reached via this extension, canonicalized
                // for unstranded data.
                let raw = kmer.extend(base, dir);
                let neighbor = if stranded { raw } else { raw.min_rc() };

                if valid_kmers
                    .binary_search_by_key(&neighbor, |entry| entry.0)
                    .is_ok()
                {
                    retained = retained.set(dir, base);
                }
            }
        }

        (valid_kmers[idx].1).0 = retained;
    }
}