1 // Copyright (c) Facebook, Inc. and its affiliates.
3 // This source code is licensed under the MIT license found in the
4 // LICENSE file in the "hack" directory of this source tree.
6 use std::cmp::{Ord, Ordering, PartialOrd};
7 use std::hash::{Hash, Hasher};
9 use serde::{Deserialize, Serialize};
11 use arena_deserializer::impl_deserialize_in_arena;
12 use arena_trait::{Arena, TrivialDrop};
13 use ocamlrep::{FromOcamlRepIn, ToOcamlRep};
14 use ocamlrep_derive::ToOcamlRep;
16 /// The maximum height difference (or balance factor) that is allowed
17 /// in the implementation of the AVL tree.
18 const MAX_DELTA: usize = 2;
20 /// An arena-allocated map.
/// The underlying data structure is a balanced tree in which all tree
23 /// nodes are allocated inside an arena.
25 /// Currently the underlying tree is equivalent to the
26 /// one used in OCaml's `Map` type.
28 /// Note that the `Option<&'a T>` is optimized to have a size of 1 word.
/// Since the whole Map is just a 1 word pointer, it implements the
/// `Copy` trait.
33 #[derive(Debug, Deserialize, Serialize)]
35 deserialize = "K: 'de + arena_deserializer::DeserializeInArena<'de>, V: 'de + arena_deserializer::DeserializeInArena<'de>"
38 pub struct Map<'a, K, V>(
39 #[serde(deserialize_with = "arena_deserializer::arena", borrow)] Option<&'a Node<'a, K, V>>,
42 impl_deserialize_in_arena!(Map<'arena, K, V>);
44 impl<'a, K, V> TrivialDrop for Map<'a, K, V> {}
46 /// The derived implementations of Copy and Clone require that K and V be
47 /// Copy/Clone. We have no such requirement, since Map is just a pointer, so we
48 /// manually implement them here.
49 impl<'a, K, V> Clone for Map<'a, K, V> {
50 fn clone(&self) -> Self {
56 impl<'a, K, V> Copy for Map<'a, K, V> {}
58 impl<'a, K: PartialEq, V: PartialEq> PartialEq for Map<'a, K, V> {
59 fn eq(&self, other: &Self) -> bool {
60 self.iter().eq(other.iter())
64 impl<'a, K: Eq, V: Eq> Eq for Map<'a, K, V> {}
66 impl<K: PartialOrd, V: PartialOrd> PartialOrd for Map<'_, K, V> {
68 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
69 self.iter().partial_cmp(other.iter())
73 impl<K: Ord, V: Ord> Ord for Map<'_, K, V> {
75 fn cmp(&self, other: &Self) -> Ordering {
76 self.iter().cmp(other.iter())
80 impl<K: Hash, V: Hash> Hash for Map<'_, K, V> {
81 fn hash<H: Hasher>(&self, state: &mut H) {
88 impl<K, V> Default for Map<'_, K, V> {
89 fn default() -> Self {
94 impl<K: ToOcamlRep, V: ToOcamlRep> ToOcamlRep for Map<'_, K, V> {
95 fn to_ocamlrep<'a, A: ocamlrep::Allocator>(
98 ) -> ocamlrep::OpaqueValue<'a> {
100 None => alloc.add(&()),
101 Some(val) => alloc.add(val),
106 impl<'a, K, V> FromOcamlRepIn<'a> for Map<'a, K, V>
108 K: FromOcamlRepIn<'a> + TrivialDrop,
109 V: FromOcamlRepIn<'a> + TrivialDrop,
112 value: ocamlrep::Value<'_>,
113 alloc: &'a bumpalo::Bump,
114 ) -> Result<Self, ocamlrep::FromError> {
115 if value.is_immediate() {
116 let _ = ocamlrep::from::expect_nullary_variant(value, 0)?;
120 alloc.alloc(<Node<'a, K, V>>::from_ocamlrep_in(value, alloc)?),
126 #[derive(Debug, Deserialize, Serialize, ToOcamlRep)]
128 deserialize = "K: 'de + arena_deserializer::DeserializeInArena<'de>, V: 'de + arena_deserializer::DeserializeInArena<'de>"
130 struct Node<'a, K, V>(
131 #[serde(deserialize_with = "arena_deserializer::arena", borrow)] Map<'a, K, V>,
132 #[serde(deserialize_with = "arena_deserializer::arena")] K,
133 #[serde(deserialize_with = "arena_deserializer::arena")] V,
134 #[serde(deserialize_with = "arena_deserializer::arena", borrow)] Map<'a, K, V>,
138 impl_deserialize_in_arena!(Node<'arena, K, V>);
140 impl<'a, K, V> FromOcamlRepIn<'a> for Node<'a, K, V>
142 K: FromOcamlRepIn<'a> + TrivialDrop,
143 V: FromOcamlRepIn<'a> + TrivialDrop,
146 value: ocamlrep::Value<'_>,
147 alloc: &'a bumpalo::Bump,
148 ) -> std::result::Result<Self, ocamlrep::FromError> {
149 let block = ocamlrep::from::expect_tuple(value, 5)?;
151 ocamlrep::from::field_in(block, 0, alloc)?,
152 ocamlrep::from::field_in(block, 1, alloc)?,
153 ocamlrep::from::field_in(block, 2, alloc)?,
154 ocamlrep::from::field_in(block, 3, alloc)?,
155 ocamlrep::from::field_in(block, 4, alloc)?,
160 impl<'a, K: TrivialDrop, V: TrivialDrop> TrivialDrop for Node<'a, K, V> {}
/// Build a `Map` from a list of `key => value` entries.
///
/// * `map![]` expands to `Map::empty()`.
/// * `map![arena; k1 => v1, k2 => v2]` starts from the empty map and calls
///   `add(arena, k, v)` for each entry, left to right. A trailing comma after
///   the last entry is accepted.
///
/// The expansion refers to `Map` by its bare name, so `Map` must be in scope
/// at the call site. The `arena` expression is re-evaluated for every entry.
#[macro_export]
macro_rules! map {
    ( ) => ({ Map::empty() });
    ( $arena:expr; $($x:expr => $y:expr),* $(,)? ) => ({
        // Shadowing (instead of `let mut`) avoids an unused-`mut` warning
        // when the entry list is empty.
        let temp_map = Map::empty();
        $(
            let temp_map = temp_map.add($arena, $x, $y);
        )*
        temp_map
    });
}
174 impl<'a, K, V> Map<'a, K, V> {
175 pub fn keys(&self) -> impl Iterator<Item = &'a K> {
176 self.iter().map(|(k, _v)| k)
180 impl<'a, K: Ord, V> Map<'a, K, V> {
181 /// Check whether a key is present in the map.
182 pub fn mem(self, x: &K) -> bool {
185 Map(Some(Node(l, v, _d, r, _h))) => match x.cmp(v) {
186 Ordering::Equal => true,
187 Ordering::Less => l.mem(x),
188 Ordering::Greater => r.mem(x),
194 impl<'a, K, V> Map<'a, K, V> {
195 /// Create a new empty map.
197 /// Note that this does not require heap allocation,
198 /// as it is equivalent to a 1 word null pointer.
199 pub const fn empty() -> Self {
203 /// Compute the number of entries in a map.
205 /// Note that this function takes linear time and logarithmic
206 /// stack space in the size of the map.
207 pub fn count(self) -> usize {
210 Map(Some(Node(l, _, _, r, _))) => l.count() + 1 + r.count(),
215 impl<'a, K: Ord, V> Map<'a, K, V> {
216 /// Check whether the map is empty.
217 pub fn is_empty(self) -> bool {
220 Map(Some(_)) => false,
225 impl<'a, K: TrivialDrop + Clone + Ord, V: TrivialDrop + Clone> Map<'a, K, V> {
226 /// Returns a one-element map.
227 pub fn singleton<A: Arena>(arena: &'a A, x: K, d: V) -> Self {
228 let node = Node(Map(None), x, d, Map(None), 1);
229 return Map(Some(arena.alloc(node)));
232 /// Create a map from an iterator.
233 pub fn from<A: Arena, I>(arena: &'a A, i: I) -> Self
235 I: IntoIterator<Item = (K, V)>,
237 let mut m = Self::empty();
240 m = m.add(arena, k, v);
246 /// Returns a pointer the current entry belonging to the key,
247 /// or returns None, if no such entry exists.
248 pub fn get(self, x: &K) -> Option<&'a V> {
251 Map(Some(Node(l, v, d, r, _h))) => match x.cmp(v) {
252 Ordering::Equal => Some(d),
253 Ordering::Less => l.get(x),
254 Ordering::Greater => r.get(x),
259 /// Return a map containing the same entries as before,
260 /// plus a new entry. If the key was already bound,
261 /// its previous entry disappears.
262 pub fn add<A: Arena>(self, arena: &'a A, x: K, data: V) -> Self {
265 let node = Node(Self::empty(), x, data, Self::empty(), 1);
266 Map(Some(arena.alloc(node)))
268 Map(Some(Node(ref l, v, d, r, h))) => match x.cmp(v) {
270 let node = Node(*l, x, data, *r, *h);
271 Map(Some(arena.alloc(node)))
273 Ordering::Less => bal(arena, l.add(arena, x, data), v.clone(), d.clone(), *r),
274 Ordering::Greater => bal(arena, *l, v.clone(), d.clone(), r.add(arena, x, data)),
279 /// Returns a map containing the same entries as before,
280 /// except for the key, which is unbound in the returned map.
281 pub fn remove<A: Arena>(self, arena: &'a A, x: &K) -> Self {
283 Map(None) => Map(None),
284 Map(Some(Node(l, v, d, r, _))) => match x.cmp(v) {
285 Ordering::Equal => merge(arena, *l, *r),
286 Ordering::Less => bal(arena, l.remove(arena, x), v.clone(), d.clone(), *r),
287 Ordering::Greater => bal(arena, *l, v.clone(), d.clone(), r.remove(arena, x)),
292 pub fn add_all<A: Arena>(self, arena: &'a A, other: Self) -> Self {
295 .fold(self, |m, (k, v)| m.add(arena, k.clone(), v.clone()))
298 /// Find the minimal key-value entry.
299 pub fn min_entry(self) -> Option<(&'a K, &'a V)> {
302 Map(Some(Node(l, x, d, _r, _))) => match l {
303 Map(None) => Some((x, d)),
309 /// Remove the minimal key-value entry.
310 pub fn remove_min_entry<A: Arena>(self, arena: &'a A) -> Self {
312 Map(None) => Map(None),
313 Map(Some(Node(l, x, d, r, _))) => match l {
315 l => bal(arena, l.remove_min_entry(arena), x.clone(), d.clone(), *r),
320 /// Find the maximum key-value entry.
321 pub fn max_entry(self) -> Option<(&'a K, &'a V)> {
324 Map(Some(Node(_l, x, d, r, _))) => match r {
325 Map(None) => Some((x, d)),
331 /// Remove the maximum key-value entry.
332 pub fn remove_max_entry<A: Arena>(self, arena: &'a A) -> Self {
334 Map(None) => Map(None),
335 Map(Some(Node(l, x, d, r, _))) => match r {
337 r => bal(arena, *l, x.clone(), d.clone(), r.remove_max_entry(arena)),
342 /// Set difference. O(n*log(n))
343 pub fn diff<A: Arena>(self, arena: &'a A, other: Self) -> Self {
346 .fold(self, |set, (k, _v)| set.remove(arena, k))
350 impl<'a, K: Clone + Ord, V: Copy> Map<'a, K, V> {
351 /// Returns a copy of the current entry belonging to the key,
352 /// or returns None, if no such entry exists.
353 pub fn find(self, x: &K) -> Option<V> {
356 Map(Some(Node(l, v, d, r, _h))) => match x.cmp(v) {
357 Ordering::Equal => Some(*d),
358 Ordering::Less => l.find(x),
359 Ordering::Greater => r.find(x),
365 fn height<'a, K, V>(l: Map<'a, K, V>) -> usize {
368 Map(Some(Node(_, _, _, _, h))) => *h,
372 fn create<'a, A: Arena, K: TrivialDrop, V: TrivialDrop>(
381 let h = if hl >= hr { hl + 1 } else { hr + 1 };
382 let node = Node(l, x, v, r, h);
383 Map(Some(arena.alloc(node)))
386 fn bal<'a, A: Arena, K: TrivialDrop + Clone, V: TrivialDrop + Clone>(
395 if hl > hr + MAX_DELTA {
397 Map(None) => panic!("impossible"),
398 Map(Some(Node(ll, lv, ld, lr, _))) => {
399 if height(*ll) >= height(*lr) {
405 create(arena, *lr, x, d, r),
409 Map(None) => panic!("impossible"),
410 Map(Some(Node(lrl, lrv, lrd, lrr, _))) => create(
412 create(arena, *ll, lv.clone(), ld.clone(), *lrl),
415 create(arena, *lrr, x, d, r),
421 } else if hr > hl + MAX_DELTA {
423 Map(None) => panic!("impossible"),
424 Map(Some(Node(rl, rv, rd, rr, _))) => {
425 if height(*rr) >= height(*rl) {
428 create(arena, l, x, d, *rl),
435 Map(None) => panic!("impossible"),
436 Map(Some(Node(rll, rlv, rld, rlr, _))) => create(
438 create(arena, l, x, d, *rll),
441 create(arena, *rlr, rv.clone(), rd.clone(), *rr),
448 create(arena, l, x, d, r)
452 fn merge<'a, A: Arena, K: TrivialDrop + Clone + Ord, V: TrivialDrop + Clone>(
459 } else if t2.is_empty() {
462 let (x, d) = t2.min_entry().unwrap();
463 bal(arena, t1, x.clone(), d.clone(), t2.remove_min_entry(arena))
467 /// Iterator state for map.
468 pub struct MapIter<'a, K, V> {
469 stack: Vec<NodeIter<'a, K, V>>,
472 struct NodeIter<'a, K, V> {
474 node: &'a Node<'a, K, V>,
477 impl<'a, K, V> Map<'a, K, V> {
478 pub fn iter(&self) -> MapIter<'a, K, V> {
479 let stack = match self {
480 Map(None) => Vec::new(),
481 Map(Some(root)) => vec![NodeIter {
490 impl<'a, K, V> IntoIterator for &Map<'a, K, V> {
491 type Item = (&'a K, &'a V);
492 type IntoIter = MapIter<'a, K, V>;
494 fn into_iter(self) -> Self::IntoIter {
499 impl<'a, K, V> Iterator for MapIter<'a, K, V> {
500 type Item = (&'a K, &'a V);
502 fn next(&mut self) -> Option<Self::Item> {
504 match self.stack.last_mut() {
511 let Node(Map(l), _, _, _, _) = n.node;
519 self.stack.push(NodeIter {
525 None => match self.stack.pop() {
528 if let Node(_, _, _, Map(Some(n)), _) = n.node {
529 self.stack.push(NodeIter {
534 let Node(_, k, v, _, _) = n.node;
549 let arena = Bump::new();
550 assert!(Map::<i64, i64>::empty().is_empty());
551 assert!(!map![&arena; 4 => 9].is_empty());
555 fn test_singleton() {
556 let a1 = Bump::new();
557 let a2 = Bump::new();
558 assert_eq!(map![&a1; 1 => 2], map![&a2; 1 => 2]);
564 assert_eq!(Map::<i64, i64>::from(&a, Vec::new()), map![]);
566 Map::<i64, i64>::from(&a, vec![(6, 7), (8, 9)]),
567 map![&a; 8 => 9, 6 => 7]
573 mod tests_arbitrary {
577 use std::collections::{BTreeMap, BTreeSet, HashMap};
581 fn prop_mem_find(xs: Vec<(u32, u32)>, ys: Vec<u32>) -> bool {
583 let m = Map::from(&a, xs.clone());
584 let o: HashMap<u32, u32> = xs.into_iter().collect();
585 for (k, v) in o.iter() {
587 assert_eq!(m.get(k), Some(v));
590 let f = o.contains_key(&k);
591 assert_eq!(m.mem(&k), f);
592 assert_eq!(m.get(&k).is_some(), f);
598 #[derive(Clone, Debug)]
604 #[derive(Clone, Debug)]
605 struct ActionSequence<T, V>(Vec<Action<T, V>>);
607 struct ActionSequenceShrinker<T, V> {
608 seed: ActionSequence<T, V>,
612 impl<T: Arbitrary + Ord + Hash, V: Arbitrary + Eq + Hash> Arbitrary for ActionSequence<T, V> {
613 fn arbitrary(g: &mut Gen) -> Self {
616 usize::arbitrary(g) % s
618 let mut elements: BTreeSet<T> = BTreeSet::new();
619 let mut actions: Vec<Action<T, V>> = Vec::with_capacity(size);
621 let r = f64::arbitrary(g);
623 let key: T = Arbitrary::arbitrary(g);
624 elements.remove(&key);
625 actions.push(Action::Remove(key));
626 } else if !elements.is_empty() && r < 0.3 {
627 let index = usize::arbitrary(g) % elements.len();
628 let key: T = elements.iter().nth(index).unwrap().clone();
629 elements.remove(&key);
630 actions.push(Action::Remove(key));
632 let key: T = Arbitrary::arbitrary(g);
633 elements.insert(key.clone());
634 actions.push(Action::Add(key, Arbitrary::arbitrary(g)));
637 ActionSequence(actions)
640 fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
641 Box::new(ActionSequenceShrinker {
648 impl<T: Clone, V: Clone> Iterator for ActionSequenceShrinker<T, V> {
649 type Item = ActionSequence<T, V>;
651 fn next(&mut self) -> Option<ActionSequence<T, V>> {
652 let ActionSequence(ref actions) = self.seed;
653 if self.index > actions.len() {
656 let actions = actions[..self.index].to_vec();
658 Some(ActionSequence(actions))
663 fn check_height_invariant<'a, K, V>(m: Map<'a, K, V>) -> bool {
666 Map(Some(Node(l, _, _, r, h))) => {
669 let h_exp = if lh > rh { lh + 1 } else { rh + 1 };
671 println!("incorrect node height");
674 if lh > rh + MAX_DELTA || rh > lh + MAX_DELTA {
675 println!("height difference invariant violated");
678 check_height_invariant(*l) && check_height_invariant(*r)
684 fn prop_action_seq(actions: ActionSequence<u32, u32>) -> bool {
685 let ActionSequence(ref actions) = actions;
687 let mut m: Map<'_, u32, u32> = Map::empty();
688 let mut o: BTreeMap<u32, u32> = BTreeMap::new();
689 for action in actions {
691 Action::Add(key, value) => {
692 m = m.add(&a, *key, *value);
693 o.insert(*key, *value);
695 Action::Remove(key) => {
696 m = m.remove(&a, key);
701 if !m.into_iter().eq(o.iter()) {
702 println!("EXPECTED {:?} GOT {:?}", o, m);
705 check_height_invariant(m)
717 fn test_iter_manual() {
719 Map::<i64, i64>::empty()
721 .map(|(k, v)| (*k, *v))
722 .collect::<Vec<(i64, i64)>>(),
730 let arena = Bump::new();
732 let empty = Map::empty();
737 create(a, empty, 1, (), empty),
740 create(a, create(a, empty, 3, (), empty), 4, (), empty),
744 create(a, empty, 6, (), create(a, empty, 7, (), empty)),
747 map.into_iter().map(|(k, ())| *k).collect::<Vec<i64>>(),
748 vec![1, 2, 3, 4, 5, 6, 7]