|
| 1 | +use std::collections::BinaryHeap; |
| 2 | +use std::fmt::Debug; |
| 3 | +use std::hash::Hash; |
| 4 | + |
| 5 | +use rustc_hash::FxHashMap; |
| 6 | +use smallvec::SmallVec; |
| 7 | + |
| 8 | +use crate::TantivyError; |
| 9 | + |
| 10 | +/// Map backed by a hash map for fast access and a binary heap to track the |
| 11 | +/// highest key. The key is an array of fixed size S. |
| 12 | +#[derive(Clone, Debug)] |
| 13 | +struct ArrayHeapMap<K: Ord, V, const S: usize> { |
| 14 | + pub(crate) buckets: FxHashMap<[K; S], V>, |
| 15 | + pub(crate) heap: BinaryHeap<[K; S]>, |
| 16 | +} |
| 17 | + |
| 18 | +impl<K: Ord, V, const S: usize> Default for ArrayHeapMap<K, V, S> { |
| 19 | + fn default() -> Self { |
| 20 | + ArrayHeapMap { |
| 21 | + buckets: FxHashMap::default(), |
| 22 | + heap: BinaryHeap::default(), |
| 23 | + } |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +impl<K: Eq + Hash + Clone + Ord, V, const S: usize> ArrayHeapMap<K, V, S> { |
| 28 | + /// Panics if the length of `key` is not S. |
| 29 | + fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V { |
| 30 | + let key_array: &[K; S] = key.try_into().expect("Key length mismatch"); |
| 31 | + self.buckets.entry(key_array.clone()).or_insert_with(|| { |
| 32 | + self.heap.push(key_array.clone()); |
| 33 | + f() |
| 34 | + }) |
| 35 | + } |
| 36 | + |
| 37 | + /// Panics if the length of `key` is not S. |
| 38 | + fn get_mut(&mut self, key: &[K]) -> Option<&mut V> { |
| 39 | + let key_array: &[K; S] = key.try_into().expect("Key length mismatch"); |
| 40 | + self.buckets.get_mut(key_array) |
| 41 | + } |
| 42 | + |
| 43 | + fn peek_highest(&self) -> Option<&[K]> { |
| 44 | + self.heap.peek().map(|k_array| k_array.as_slice()) |
| 45 | + } |
| 46 | + |
| 47 | + fn evict_highest(&mut self) { |
| 48 | + if let Some(highest) = self.heap.pop() { |
| 49 | + self.buckets.remove(&highest); |
| 50 | + } |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +impl<K: Copy + Ord + Clone + 'static, V: 'static, const S: usize> ArrayHeapMap<K, V, S> { |
| 55 | + fn into_iter(self) -> Box<dyn Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)>> { |
| 56 | + Box::new( |
| 57 | + self.buckets |
| 58 | + .into_iter() |
| 59 | + .map(|(k, v)| (SmallVec::from_slice(&k), v)), |
| 60 | + ) |
| 61 | + } |
| 62 | + |
| 63 | + fn values_mut<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a mut V> + 'a> { |
| 64 | + Box::new(self.buckets.values_mut()) |
| 65 | + } |
| 66 | +} |
| 67 | + |
| 68 | +pub(super) const MAX_DYN_ARRAY_SIZE: usize = 16; |
| 69 | +const MAX_DYN_ARRAY_SIZE_PLUS_ONE: usize = MAX_DYN_ARRAY_SIZE + 1; |
| 70 | + |
| 71 | +/// A map optimized for memory footprint, fast access and efficient eviction of |
| 72 | +/// the highest key. |
| 73 | +/// |
| 74 | +/// Keys are inlined arrays of size 1 to [MAX_DYN_ARRAY_SIZE] but for a given |
| 75 | +/// instance the key size is fixed. This allows to avoid heap allocations for the |
| 76 | +/// keys. |
| 77 | +#[derive(Clone, Debug)] |
| 78 | +pub(super) struct DynArrayHeapMap<K: Ord, V>(DynArrayHeapMapInner<K, V>); |
| 79 | + |
| 80 | +/// Wrapper around ArrayHeapMap to dynamically dispatch on the array size. |
| 81 | +#[derive(Clone, Debug)] |
| 82 | +enum DynArrayHeapMapInner<K: Ord, V> { |
| 83 | + Dim1(ArrayHeapMap<K, V, 1>), |
| 84 | + Dim2(ArrayHeapMap<K, V, 2>), |
| 85 | + Dim3(ArrayHeapMap<K, V, 3>), |
| 86 | + Dim4(ArrayHeapMap<K, V, 4>), |
| 87 | + Dim5(ArrayHeapMap<K, V, 5>), |
| 88 | + Dim6(ArrayHeapMap<K, V, 6>), |
| 89 | + Dim7(ArrayHeapMap<K, V, 7>), |
| 90 | + Dim8(ArrayHeapMap<K, V, 8>), |
| 91 | + Dim9(ArrayHeapMap<K, V, 9>), |
| 92 | + Dim10(ArrayHeapMap<K, V, 10>), |
| 93 | + Dim11(ArrayHeapMap<K, V, 11>), |
| 94 | + Dim12(ArrayHeapMap<K, V, 12>), |
| 95 | + Dim13(ArrayHeapMap<K, V, 13>), |
| 96 | + Dim14(ArrayHeapMap<K, V, 14>), |
| 97 | + Dim15(ArrayHeapMap<K, V, 15>), |
| 98 | + Dim16(ArrayHeapMap<K, V, 16>), |
| 99 | +} |
| 100 | + |
| 101 | +impl<K: Ord, V> DynArrayHeapMap<K, V> { |
| 102 | + /// Creates a new heap map with dynamic array keys of size `key_dimension`. |
| 103 | + pub(super) fn try_new(key_dimension: usize) -> crate::Result<Self> { |
| 104 | + let inner = match key_dimension { |
| 105 | + 0 => { |
| 106 | + return Err(TantivyError::InvalidArgument( |
| 107 | + "DynArrayHeapMap dimension must be at least 1".to_string(), |
| 108 | + )) |
| 109 | + } |
| 110 | + 1 => DynArrayHeapMapInner::Dim1(ArrayHeapMap::default()), |
| 111 | + 2 => DynArrayHeapMapInner::Dim2(ArrayHeapMap::default()), |
| 112 | + 3 => DynArrayHeapMapInner::Dim3(ArrayHeapMap::default()), |
| 113 | + 4 => DynArrayHeapMapInner::Dim4(ArrayHeapMap::default()), |
| 114 | + 5 => DynArrayHeapMapInner::Dim5(ArrayHeapMap::default()), |
| 115 | + 6 => DynArrayHeapMapInner::Dim6(ArrayHeapMap::default()), |
| 116 | + 7 => DynArrayHeapMapInner::Dim7(ArrayHeapMap::default()), |
| 117 | + 8 => DynArrayHeapMapInner::Dim8(ArrayHeapMap::default()), |
| 118 | + 9 => DynArrayHeapMapInner::Dim9(ArrayHeapMap::default()), |
| 119 | + 10 => DynArrayHeapMapInner::Dim10(ArrayHeapMap::default()), |
| 120 | + 11 => DynArrayHeapMapInner::Dim11(ArrayHeapMap::default()), |
| 121 | + 12 => DynArrayHeapMapInner::Dim12(ArrayHeapMap::default()), |
| 122 | + 13 => DynArrayHeapMapInner::Dim13(ArrayHeapMap::default()), |
| 123 | + 14 => DynArrayHeapMapInner::Dim14(ArrayHeapMap::default()), |
| 124 | + 15 => DynArrayHeapMapInner::Dim15(ArrayHeapMap::default()), |
| 125 | + 16 => DynArrayHeapMapInner::Dim16(ArrayHeapMap::default()), |
| 126 | + MAX_DYN_ARRAY_SIZE_PLUS_ONE.. => { |
| 127 | + return Err(TantivyError::InvalidArgument(format!( |
| 128 | + "DynArrayHeapMap supports maximum {MAX_DYN_ARRAY_SIZE} dimensions, got \ |
| 129 | + {key_dimension}", |
| 130 | + ))) |
| 131 | + } |
| 132 | + }; |
| 133 | + Ok(DynArrayHeapMap(inner)) |
| 134 | + } |
| 135 | + |
| 136 | + /// Number of elements in the map. This is not the dimension of the keys. |
| 137 | + pub(super) fn size(&self) -> usize { |
| 138 | + match &self.0 { |
| 139 | + DynArrayHeapMapInner::Dim1(map) => map.buckets.len(), |
| 140 | + DynArrayHeapMapInner::Dim2(map) => map.buckets.len(), |
| 141 | + DynArrayHeapMapInner::Dim3(map) => map.buckets.len(), |
| 142 | + DynArrayHeapMapInner::Dim4(map) => map.buckets.len(), |
| 143 | + DynArrayHeapMapInner::Dim5(map) => map.buckets.len(), |
| 144 | + DynArrayHeapMapInner::Dim6(map) => map.buckets.len(), |
| 145 | + DynArrayHeapMapInner::Dim7(map) => map.buckets.len(), |
| 146 | + DynArrayHeapMapInner::Dim8(map) => map.buckets.len(), |
| 147 | + DynArrayHeapMapInner::Dim9(map) => map.buckets.len(), |
| 148 | + DynArrayHeapMapInner::Dim10(map) => map.buckets.len(), |
| 149 | + DynArrayHeapMapInner::Dim11(map) => map.buckets.len(), |
| 150 | + DynArrayHeapMapInner::Dim12(map) => map.buckets.len(), |
| 151 | + DynArrayHeapMapInner::Dim13(map) => map.buckets.len(), |
| 152 | + DynArrayHeapMapInner::Dim14(map) => map.buckets.len(), |
| 153 | + DynArrayHeapMapInner::Dim15(map) => map.buckets.len(), |
| 154 | + DynArrayHeapMapInner::Dim16(map) => map.buckets.len(), |
| 155 | + } |
| 156 | + } |
| 157 | +} |
| 158 | + |
| 159 | +impl<K: Ord + Hash + Clone, V> DynArrayHeapMap<K, V> { |
| 160 | + /// Get a mutable reference to the value corresponding to `key` or inserts a new |
| 161 | + /// value created by calling `f`. |
| 162 | + /// |
| 163 | + /// Panics if the length of `key` does not match the key dimension of the map. |
| 164 | + pub(super) fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V { |
| 165 | + match &mut self.0 { |
| 166 | + DynArrayHeapMapInner::Dim1(map) => map.get_or_insert_with(key, f), |
| 167 | + DynArrayHeapMapInner::Dim2(map) => map.get_or_insert_with(key, f), |
| 168 | + DynArrayHeapMapInner::Dim3(map) => map.get_or_insert_with(key, f), |
| 169 | + DynArrayHeapMapInner::Dim4(map) => map.get_or_insert_with(key, f), |
| 170 | + DynArrayHeapMapInner::Dim5(map) => map.get_or_insert_with(key, f), |
| 171 | + DynArrayHeapMapInner::Dim6(map) => map.get_or_insert_with(key, f), |
| 172 | + DynArrayHeapMapInner::Dim7(map) => map.get_or_insert_with(key, f), |
| 173 | + DynArrayHeapMapInner::Dim8(map) => map.get_or_insert_with(key, f), |
| 174 | + DynArrayHeapMapInner::Dim9(map) => map.get_or_insert_with(key, f), |
| 175 | + DynArrayHeapMapInner::Dim10(map) => map.get_or_insert_with(key, f), |
| 176 | + DynArrayHeapMapInner::Dim11(map) => map.get_or_insert_with(key, f), |
| 177 | + DynArrayHeapMapInner::Dim12(map) => map.get_or_insert_with(key, f), |
| 178 | + DynArrayHeapMapInner::Dim13(map) => map.get_or_insert_with(key, f), |
| 179 | + DynArrayHeapMapInner::Dim14(map) => map.get_or_insert_with(key, f), |
| 180 | + DynArrayHeapMapInner::Dim15(map) => map.get_or_insert_with(key, f), |
| 181 | + DynArrayHeapMapInner::Dim16(map) => map.get_or_insert_with(key, f), |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + /// Returns a mutable reference to the value corresponding to `key`. |
| 186 | + /// |
| 187 | + /// Panics if the length of `key` does not match the key dimension of the map. |
| 188 | + pub fn get_mut(&mut self, key: &[K]) -> Option<&mut V> { |
| 189 | + match &mut self.0 { |
| 190 | + DynArrayHeapMapInner::Dim1(map) => map.get_mut(key), |
| 191 | + DynArrayHeapMapInner::Dim2(map) => map.get_mut(key), |
| 192 | + DynArrayHeapMapInner::Dim3(map) => map.get_mut(key), |
| 193 | + DynArrayHeapMapInner::Dim4(map) => map.get_mut(key), |
| 194 | + DynArrayHeapMapInner::Dim5(map) => map.get_mut(key), |
| 195 | + DynArrayHeapMapInner::Dim6(map) => map.get_mut(key), |
| 196 | + DynArrayHeapMapInner::Dim7(map) => map.get_mut(key), |
| 197 | + DynArrayHeapMapInner::Dim8(map) => map.get_mut(key), |
| 198 | + DynArrayHeapMapInner::Dim9(map) => map.get_mut(key), |
| 199 | + DynArrayHeapMapInner::Dim10(map) => map.get_mut(key), |
| 200 | + DynArrayHeapMapInner::Dim11(map) => map.get_mut(key), |
| 201 | + DynArrayHeapMapInner::Dim12(map) => map.get_mut(key), |
| 202 | + DynArrayHeapMapInner::Dim13(map) => map.get_mut(key), |
| 203 | + DynArrayHeapMapInner::Dim14(map) => map.get_mut(key), |
| 204 | + DynArrayHeapMapInner::Dim15(map) => map.get_mut(key), |
| 205 | + DynArrayHeapMapInner::Dim16(map) => map.get_mut(key), |
| 206 | + } |
| 207 | + } |
| 208 | + |
| 209 | + /// Returns a reference to the highest key in the map. |
| 210 | + pub(super) fn peek_highest(&self) -> Option<&[K]> { |
| 211 | + match &self.0 { |
| 212 | + DynArrayHeapMapInner::Dim1(map) => map.peek_highest(), |
| 213 | + DynArrayHeapMapInner::Dim2(map) => map.peek_highest(), |
| 214 | + DynArrayHeapMapInner::Dim3(map) => map.peek_highest(), |
| 215 | + DynArrayHeapMapInner::Dim4(map) => map.peek_highest(), |
| 216 | + DynArrayHeapMapInner::Dim5(map) => map.peek_highest(), |
| 217 | + DynArrayHeapMapInner::Dim6(map) => map.peek_highest(), |
| 218 | + DynArrayHeapMapInner::Dim7(map) => map.peek_highest(), |
| 219 | + DynArrayHeapMapInner::Dim8(map) => map.peek_highest(), |
| 220 | + DynArrayHeapMapInner::Dim9(map) => map.peek_highest(), |
| 221 | + DynArrayHeapMapInner::Dim10(map) => map.peek_highest(), |
| 222 | + DynArrayHeapMapInner::Dim11(map) => map.peek_highest(), |
| 223 | + DynArrayHeapMapInner::Dim12(map) => map.peek_highest(), |
| 224 | + DynArrayHeapMapInner::Dim13(map) => map.peek_highest(), |
| 225 | + DynArrayHeapMapInner::Dim14(map) => map.peek_highest(), |
| 226 | + DynArrayHeapMapInner::Dim15(map) => map.peek_highest(), |
| 227 | + DynArrayHeapMapInner::Dim16(map) => map.peek_highest(), |
| 228 | + } |
| 229 | + } |
| 230 | + |
| 231 | + /// Removes the entry with the highest key from the map. |
| 232 | + pub(super) fn evict_highest(&mut self) { |
| 233 | + match &mut self.0 { |
| 234 | + DynArrayHeapMapInner::Dim1(map) => map.evict_highest(), |
| 235 | + DynArrayHeapMapInner::Dim2(map) => map.evict_highest(), |
| 236 | + DynArrayHeapMapInner::Dim3(map) => map.evict_highest(), |
| 237 | + DynArrayHeapMapInner::Dim4(map) => map.evict_highest(), |
| 238 | + DynArrayHeapMapInner::Dim5(map) => map.evict_highest(), |
| 239 | + DynArrayHeapMapInner::Dim6(map) => map.evict_highest(), |
| 240 | + DynArrayHeapMapInner::Dim7(map) => map.evict_highest(), |
| 241 | + DynArrayHeapMapInner::Dim8(map) => map.evict_highest(), |
| 242 | + DynArrayHeapMapInner::Dim9(map) => map.evict_highest(), |
| 243 | + DynArrayHeapMapInner::Dim10(map) => map.evict_highest(), |
| 244 | + DynArrayHeapMapInner::Dim11(map) => map.evict_highest(), |
| 245 | + DynArrayHeapMapInner::Dim12(map) => map.evict_highest(), |
| 246 | + DynArrayHeapMapInner::Dim13(map) => map.evict_highest(), |
| 247 | + DynArrayHeapMapInner::Dim14(map) => map.evict_highest(), |
| 248 | + DynArrayHeapMapInner::Dim15(map) => map.evict_highest(), |
| 249 | + DynArrayHeapMapInner::Dim16(map) => map.evict_highest(), |
| 250 | + } |
| 251 | + } |
| 252 | +} |
| 253 | + |
| 254 | +impl<K: Ord + Clone + Copy + 'static, V: 'static> DynArrayHeapMap<K, V> { |
| 255 | + /// Turns this map into an iterator over key-value pairs. |
| 256 | + pub fn into_iter(self) -> impl Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)> { |
| 257 | + match self.0 { |
| 258 | + DynArrayHeapMapInner::Dim1(map) => map.into_iter(), |
| 259 | + DynArrayHeapMapInner::Dim2(map) => map.into_iter(), |
| 260 | + DynArrayHeapMapInner::Dim3(map) => map.into_iter(), |
| 261 | + DynArrayHeapMapInner::Dim4(map) => map.into_iter(), |
| 262 | + DynArrayHeapMapInner::Dim5(map) => map.into_iter(), |
| 263 | + DynArrayHeapMapInner::Dim6(map) => map.into_iter(), |
| 264 | + DynArrayHeapMapInner::Dim7(map) => map.into_iter(), |
| 265 | + DynArrayHeapMapInner::Dim8(map) => map.into_iter(), |
| 266 | + DynArrayHeapMapInner::Dim9(map) => map.into_iter(), |
| 267 | + DynArrayHeapMapInner::Dim10(map) => map.into_iter(), |
| 268 | + DynArrayHeapMapInner::Dim11(map) => map.into_iter(), |
| 269 | + DynArrayHeapMapInner::Dim12(map) => map.into_iter(), |
| 270 | + DynArrayHeapMapInner::Dim13(map) => map.into_iter(), |
| 271 | + DynArrayHeapMapInner::Dim14(map) => map.into_iter(), |
| 272 | + DynArrayHeapMapInner::Dim15(map) => map.into_iter(), |
| 273 | + DynArrayHeapMapInner::Dim16(map) => map.into_iter(), |
| 274 | + } |
| 275 | + } |
| 276 | + |
| 277 | + /// Returns an iterator over mutable references to the values in the map. |
| 278 | + pub(super) fn values_mut(&mut self) -> impl Iterator<Item = &mut V> { |
| 279 | + match &mut self.0 { |
| 280 | + DynArrayHeapMapInner::Dim1(map) => map.values_mut(), |
| 281 | + DynArrayHeapMapInner::Dim2(map) => map.values_mut(), |
| 282 | + DynArrayHeapMapInner::Dim3(map) => map.values_mut(), |
| 283 | + DynArrayHeapMapInner::Dim4(map) => map.values_mut(), |
| 284 | + DynArrayHeapMapInner::Dim5(map) => map.values_mut(), |
| 285 | + DynArrayHeapMapInner::Dim6(map) => map.values_mut(), |
| 286 | + DynArrayHeapMapInner::Dim7(map) => map.values_mut(), |
| 287 | + DynArrayHeapMapInner::Dim8(map) => map.values_mut(), |
| 288 | + DynArrayHeapMapInner::Dim9(map) => map.values_mut(), |
| 289 | + DynArrayHeapMapInner::Dim10(map) => map.values_mut(), |
| 290 | + DynArrayHeapMapInner::Dim11(map) => map.values_mut(), |
| 291 | + DynArrayHeapMapInner::Dim12(map) => map.values_mut(), |
| 292 | + DynArrayHeapMapInner::Dim13(map) => map.values_mut(), |
| 293 | + DynArrayHeapMapInner::Dim14(map) => map.values_mut(), |
| 294 | + DynArrayHeapMapInner::Dim15(map) => map.values_mut(), |
| 295 | + DynArrayHeapMapInner::Dim16(map) => map.values_mut(), |
| 296 | + } |
| 297 | + } |
| 298 | +} |
| 299 | + |
| 300 | +#[cfg(test)] |
| 301 | +mod tests { |
| 302 | + use super::*; |
| 303 | + |
| 304 | + #[test] |
| 305 | + fn test_dyn_array_heap_map() { |
| 306 | + let mut map = DynArrayHeapMap::<u32, &str>::try_new(2).unwrap(); |
| 307 | + // insert |
| 308 | + let key1 = [1u32, 2u32]; |
| 309 | + let key2 = [2u32, 1u32]; |
| 310 | + map.get_or_insert_with(&key1, || "a"); |
| 311 | + map.get_or_insert_with(&key2, || "b"); |
| 312 | + assert_eq!(map.size(), 2); |
| 313 | + |
| 314 | + // evict highest |
| 315 | + assert_eq!(map.peek_highest(), Some(&key2[..])); |
| 316 | + map.evict_highest(); |
| 317 | + assert_eq!(map.size(), 1); |
| 318 | + assert_eq!(map.peek_highest(), Some(&key1[..])); |
| 319 | + |
| 320 | + // mutable iterator |
| 321 | + { |
| 322 | + let mut mut_iter = map.values_mut(); |
| 323 | + let v = mut_iter.next().unwrap(); |
| 324 | + assert_eq!(*v, "a"); |
| 325 | + *v = "c"; |
| 326 | + assert_eq!(mut_iter.next(), None); |
| 327 | + } |
| 328 | + |
| 329 | + // into_iter |
| 330 | + let mut iter = map.into_iter(); |
| 331 | + let (k, v) = iter.next().unwrap(); |
| 332 | + assert_eq!(k.as_slice(), &key1); |
| 333 | + assert_eq!(v, "c"); |
| 334 | + assert_eq!(iter.next(), None); |
| 335 | + } |
| 336 | +} |
0 commit comments