about summary refs log tree commit diff stats
path: root/crates/core/src/egraph/union_find.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/core/src/egraph/union_find.rs')
-rw-r--r--crates/core/src/egraph/union_find.rs363
1 files changed, 360 insertions, 3 deletions
diff --git a/crates/core/src/egraph/union_find.rs b/crates/core/src/egraph/union_find.rs
index 8d5422d..daae62c 100644
--- a/crates/core/src/egraph/union_find.rs
+++ b/crates/core/src/egraph/union_find.rs
@@ -1,6 +1,363 @@
 //! Union-find / disjoint-set data structure implementations.
 
-/// Simple union-find implementation.
+use alloc::{
+    alloc::{Allocator, Global},
+    collections::TryReserveError,
+    vec::Vec,
+};
+use core::cmp;
+
+use crate::id::Identifier;
+
+//NOTE: maybe we should look into finding a way to bind identifiers to the
+//      isize cap of vector sizes to avoid the error condition that would
+//      result from saturating this bound?
+
+/// Simple union-find implementation using union by min-id.
 ///
-/// Operates according to a union by min-id scheme.
-pub struct UnionFind {}
+/// All [`Identifier`]s are assumed to be their own parents by default.
+#[derive(Clone)]
+#[cfg_attr(feature = "core-fmt", derive(Debug))]
+pub struct UnionFind<Id, A = Global>
+where
+    A: Allocator + Clone,
+{
+    /// Parent table: the entry at index `i` is the parent of e-class `i`.
+    ///
+    /// # Invariants
+    ///
+    /// Every stored parent `p` satisfies `p.into_usize() < self.inner.len()`,
+    /// and `Id::from_usize(i)` succeeds for every `i < self.inner.len()`.
+    pub(crate) inner: Vec<Id, A>,
+
+    /// Allocator handle; `new_in` gives a clone of it to `inner`.
+    pub(crate) alloc: A,
+}
+
+impl<Id> UnionFind<Id, Global> {
+    /// Creates a union-find that represents no e-classes, using [`Global`].
+    #[must_use]
+    #[inline]
+    pub fn new() -> Self {
+        Self::new_in(Global)
+    }
+}
+
+impl<Id, A> UnionFind<Id, A>
+where
+    A: Allocator + Clone,
+{
+    /// Creates a union-find that represents no e-classes.
+    ///
+    /// `alloc` is kept in the structure and a clone of it backs the storage.
+    #[must_use]
+    #[inline]
+    pub fn new_in(alloc: A) -> Self {
+        //NOTE: the vector owns its own clone of the allocator
+        let inner = Vec::new_in(alloc.clone());
+        Self { inner, alloc }
+    }
+}
+
+impl<Id, A> UnionFind<Id, A>
+where
+    A: Allocator + Clone,
+    Id: Identifier,
+{
+    /// Converts an index of the parent vector into its [`Identifier`],
+    /// relying on the invariant that every such index is convertible.
+    #[must_use]
+    #[inline(always)]
+    fn index_to_id(id: usize) -> Id {
+        let converted = Id::from_usize(id);
+        debug_assert!(converted.is_ok());
+        //SAFETY: callers only pass indices of e-classes that are, or are
+        //        about to be, represented; those are assumed to fit an id
+        unsafe { converted.unwrap_unchecked() }
+    }
+
+    /// Restores every represented e-class to being its own parent.
+    ///
+    /// No e-classes are added or removed and no memory is released.
+    #[inline]
+    pub fn reset(&mut self) {
+        self.inner
+            .iter_mut()
+            .enumerate()
+            .for_each(|(index, slot)| *slot = Self::index_to_id(index));
+    }
+
+    /// Forgets every represented e-class, leaving the structure empty.
+    ///
+    /// Only [`Vec::clear`] is called on the parent vector, so the allocated
+    /// capacity is kept for later reuse.
+    #[inline(always)]
+    pub fn clear(&mut self) {
+        Vec::clear(&mut self.inner);
+    }
+
+    /// Asks the parent vector to drop as much unused capacity as it can, by
+    /// delegating to [`Vec::shrink_to_fit`].
+    #[inline(always)]
+    pub fn shrink_to_fit(&mut self) {
+        Vec::shrink_to_fit(&mut self.inner);
+    }
+
+    /// Returns the greatest represented e-class id, or [`None`] if the
+    /// structure is empty.
+    #[must_use]
+    #[inline]
+    pub fn greatest_represented(&self) -> Option<Id> {
+        //NOTE: `checked_sub` yields `None` exactly when the length is zero
+        self.inner.len().checked_sub(1).map(Self::index_to_id)
+    }
+
+    /// Returns `true` when no e-class is represented.
+    #[inline(always)]
+    pub fn is_empty(&self) -> bool {
+        self.inner.as_slice().is_empty()
+    }
+
+    /// Grows the structure until every id yielded by `ids` is represented;
+    /// an empty iterator is a no-op.
+    ///
+    /// # Errors
+    ///
+    /// Fails if [`Self::ensure_contains`] fails for the greatest id.
+    #[inline(always)]
+    pub fn ensure_contains_all(
+        &mut self,
+        ids: impl IntoIterator<Item = Id>,
+    ) -> Result<(), TryReserveError> {
+        match ids.into_iter().max() {
+            Some(greatest) => self.ensure_contains(greatest),
+            None => Ok(()),
+        }
+    }
+
+    /// Grows the parent vector until the given e-class id is represented.
+    ///
+    /// Every newly represented e-class starts out as its own parent; nothing
+    /// happens if `id` is already represented.
+    ///
+    /// # Errors
+    ///
+    /// An error is returned if reserving space for the new ids fails, in
+    /// which case no e-class is added.
+    #[inline]
+    pub fn ensure_contains(&mut self, id: Id) -> Result<(), TryReserveError> {
+        let target = id.into_usize();
+        let old_len = self.inner.len();
+
+        if target < old_len {
+            return Ok(());
+        }
+
+        //NOTE: `target >= old_len` holds here, so the subtraction cannot
+        //      wrap, and `saturating_add` prevents roll-over to 0 when the
+        //      gap is `usize::MAX`
+        let additional = target.wrapping_sub(old_len).saturating_add(1);
+        self.inner.try_reserve(additional)?;
+
+        //NOTE: the successful reservation provides a spare slot for each
+        //      index in `fresh`
+        let fresh = old_len..=target;
+        let spare = self.inner.spare_capacity_mut().iter_mut();
+        for (slot, index) in spare.zip(fresh) {
+            slot.write(Self::index_to_id(index));
+        }
+
+        //NOTE: vectors of non-ZSTs may hold at most `isize::MAX` bytes, so
+        //      after a successful reservation `target` is nowhere near
+        //      `usize::MAX` and the addition cannot wrap
+        //SAFETY: the slots for `old_len..=target` were written just above
+        unsafe {
+            self.inner.set_len(target.wrapping_add(1));
+        }
+
+        Ok(())
+    }
+
+    /// Follows parent links from the given e-class up to its canonical
+    /// representation without modifying the structure; an id that is not
+    /// represented is returned unchanged.
+    #[inline]
+    pub fn find(&self, id: Id) -> Id {
+        let parents = self.inner.as_slice();
+        if id.into_usize() >= parents.len() {
+            return id;
+        }
+
+        let mut current = id;
+        loop {
+            //SAFETY: `id` is in bounds and, by the invariants on the parent
+            //        vector, so is every parent reachable from it
+            let next = *unsafe { parents.get_unchecked(current.into_usize()) };
+            if current == next {
+                return current;
+            }
+            current = next;
+        }
+    }
+
+    /// Indicates if two e-class ids have the same canonical representation.
+    #[inline(always)]
+    pub fn equivalent(&self, a: Id, b: Id) -> bool {
+        self.find(a).eq(&self.find(b))
+    }
+
+    /// Path-compressing check that two ids have the same representation.
+    #[inline(always)]
+    pub fn equivalent_mut(&mut self, a: Id, b: Id) -> bool {
+        let root_a = self.find_mut(a);
+        root_a == self.find_mut(b)
+    }
+
+    /// Locates the canonical representation of the given e-class, performing
+    /// path compression: each non-root e-class visited on the way up is
+    /// re-pointed at its grandparent.
+    #[inline]
+    pub fn find_mut(&mut self, id: Id) -> Id {
+        if id.into_usize() >= self.inner.len() {
+            return id;
+        }
+
+        let parents = self.inner.as_mut_ptr();
+        let mut node = id;
+        //SAFETY: `node` is in bounds after the check above and is only ever
+        //        replaced by stored parents, which are in bounds due to the
+        //        invariants on the contents of the vector
+        unsafe {
+            loop {
+                let slot = parents.add(node.into_usize());
+                let parent = slot.read();
+                if node == parent {
+                    return node;
+                }
+
+                let grandparent = parents.add(parent.into_usize()).read();
+                slot.write(grandparent);
+                node = grandparent;
+            }
+        }
+    }
+
+    /* TODO: implement this and the data structure it will use
+    /// Locates the canonical representation of several e-classes, performing path compression.
+    #[inline]
+    pub fn find_mut_all(&mut self, ids: impl IntoIterator<Item = Id>) -> Result<FindMutAll<Id>, TryReserveError> {
+        //NOTE: needs a try_extend or something to fallibly collect the iterator into a new vec and sort it. we can do this but it'd be more work in the utilities file
+    }
+    */
+
+    /// Merges the e-classes of `a` and `b`, returning both representatives
+    /// with the one of smaller id as the parent.
+    ///
+    /// Space is reserved for unrepresented ids whenever a merge takes place.
+    ///
+    /// # Errors
+    ///
+    /// An error is returned if reserving space for new ids fails.
+    #[inline]
+    pub fn merge(
+        &mut self,
+        a: Id,
+        b: Id,
+    ) -> Result<Union<Id>, TryReserveError> {
+        let root_a = self.find_mut(a);
+        let root_b = self.find_mut(b);
+        let [parent, child] = cmp::minmax(root_a, root_b);
+
+        if parent == child {
+            return Ok(Union { parent, child });
+        }
+
+        //NOTE: `child` is the greater id, so this covers `parent` as well
+        self.ensure_contains(child)?;
+        let index = child.into_usize();
+        //SAFETY: `index` is in bounds, either by the invariants on the
+        //        inner vector or because space was just reserved for it
+        let slot = unsafe { self.inner.get_unchecked_mut(index) };
+        *slot = parent;
+
+        Ok(Union { parent, child })
+    }
+}
+
+impl<Id> Default for UnionFind<Id, Global> {
+    fn default() -> Self {
+        Self::new_in(Global)
+    }
+}
+
+/// Outcome of a merge: the two representatives that were involved.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct Union<Id> {
+    /// Representative acting as the parent e-class.
+    pub parent: Id,
+
+    /// Representative acting as the child e-class.
+    pub child: Id,
+}
+
+impl<Id: Identifier> Union<Id> {
+    /// Indicates if an e-class merge was made in this union operation.
+    ///
+    /// This is the case exactly when `parent` and `child` differ.
+    #[must_use]
+    #[inline(always)]
+    pub fn changed(&self) -> bool {
+        let Self { parent, child } = self;
+        parent != child
+    }
+}
+
+/// Iterator state for locating canonical representations of e-classes.
+pub struct FindMutAll<'a, Id, A>
+where
+    A: Allocator + Clone,
+{
+    /// Union-find, mutably borrowed, that the representations are located in.
+    pub(crate) union_find: &'a mut UnionFind<Id, A>,
+
+    /// Ids whose canonical representations are to be located.
+    pub(crate) ids: Vec<Id, A>,
+}
+
+#[cfg(all(test, feature = "core-error"))]
+mod test {
+    use super::*;
+
+    /// Growing the structure makes every id up to the requested one its own
+    /// parent, never shrinks it, and fails for an unreservable id.
+    #[test]
+    fn union_find_reserve() {
+        //NOTE: each step pairs the id to ensure with the expected contents
+        //      of the parent vector afterwards
+        let steps: [(usize, &[usize]); 4] = [
+            // growing from empty
+            (5, &[0, 1, 2, 3, 4, 5]),
+            // growing by a single id
+            (6, &[0, 1, 2, 3, 4, 5, 6]),
+            // already represented, not the greatest id
+            (5, &[0, 1, 2, 3, 4, 5, 6]),
+            // already represented, the smallest id
+            (0, &[0, 1, 2, 3, 4, 5, 6]),
+        ];
+
+        let mut union_find = UnionFind::default();
+
+        for (id, expected) in steps {
+            union_find
+                .ensure_contains(id)
+                .expect("unable to reserve space to contain an identifier");
+
+            assert_eq!(expected, union_find.inner.as_slice());
+        }
+
+        //NOTE: an index of `usize::MAX` needs more slots than can be reserved
+        union_find
+            .ensure_contains(usize::MAX)
+            .expect_err("reserving greater than `isize::MAX` should fail");
+    }
+}