diff options
| author | superwhiskers <[email protected]> | 2025-12-17 21:22:37 -0600 |
|---|---|---|
| committer | superwhiskers <[email protected]> | 2026-01-04 22:23:01 -0600 |
| commit | 54e988aa3d31fb21d3397758f4b71d084e1a1130 (patch) | |
| tree | 8cef7d5a61946a1c90707e60e5022a11022f421d /crates/core/src/egraph/union_find.rs | |
| parent | e12b1f4459aee80ee333e90e3b56a3b09f81ae3e (diff) | |
| download | azimuth-canon.tar.gz azimuth-canon.tar.bz2 azimuth-canon.zip | |
Change-Id: I32b78b3eee68205032591578fca70c366a6a6964
Diffstat (limited to 'crates/core/src/egraph/union_find.rs')
| -rw-r--r-- | crates/core/src/egraph/union_find.rs | 363 |
1 files changed, 360 insertions, 3 deletions
diff --git a/crates/core/src/egraph/union_find.rs b/crates/core/src/egraph/union_find.rs index 8d5422d..daae62c 100644 --- a/crates/core/src/egraph/union_find.rs +++ b/crates/core/src/egraph/union_find.rs @@ -1,6 +1,363 @@ //! Union-find / disjoint-set data structure implementations. -/// Simple union-find implementation. +use alloc::{ + alloc::{Allocator, Global}, + collections::TryReserveError, + vec::Vec, +}; +use core::cmp; + +use crate::id::Identifier; + +//NOTE: maybe we should look into finding a way to bind identifiers to the +// isize cap of vector sizes to avoid the error condition that would +// result from saturating this bound? + +/// Simple union-find implementation using union by min-id. /// -/// Operates according to a union by min-id scheme. -pub struct UnionFind {} +/// All [`Identifier`]s are assumed to be their own parents by default. +#[derive(Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub struct UnionFind<Id, A = Global> +where + A: Allocator + Clone, +{ + /// Vector mapping an e-class to its parent e-class. + /// + /// # Invariants + /// + /// For every stored parent `p`, `p.into_usize() < self.inner.len()`; and + /// for every `i < self.inner.len()`, `Id::from_usize(i)` succeeds. + pub(crate) inner: Vec<Id, A>, + + /// Allocator used by the [`UnionFind`]. + pub(crate) alloc: A, +} + +impl<Id> UnionFind<Id, Global> { + /// Constructs an empty union-find data structure. + #[must_use] + #[inline] + pub fn new() -> Self { + Self::new_in(Global) + } +} + +impl<Id, A> UnionFind<Id, A> +where + A: Allocator + Clone, +{ + /// Constructs an empty union-find data structure using the provided + /// [`Allocator`]. + #[must_use] + #[inline] + pub fn new_in(alloc: A) -> Self { + Self { + inner: Vec::new_in(alloc.clone()), + alloc, + } + } +} + +impl<Id, A> UnionFind<Id, A> +where + A: Allocator + Clone, + Id: Identifier, +{ + /// Helper method for converting a `usize` into an [`Identifier`], + /// encoding our core assumption. + #[must_use] + #[inline(always)] + fn index_to_id(id: usize) -> Id { + debug_assert!(Id::from_usize(id).is_ok()); + + //SAFETY: the e-class wouldn't be representable if a usize didn't fit + // into its id + unsafe { Id::from_usize(id).unwrap_unchecked() } + } + + /// Makes all represented e-classes their own parents. + /// + /// This does not free up any space; it only edits e-classes already + /// represented. + #[inline] + pub fn reset(&mut self) { + for (child, parent) in self.inner.iter_mut().enumerate() { + *parent = Self::index_to_id(child); + } + } + + /// Removes all represented e-classes from the data structure. + /// + /// This does not return any used memory to the operating system; it only + /// calls [`Vec::clear`]. + #[inline(always)] + pub fn clear(&mut self) { + self.inner.clear(); + } + + /// Shrinks the internal data structure so that the minimum amount of + /// memory is used. + #[inline(always)] + pub fn shrink_to_fit(&mut self) { + self.inner.shrink_to_fit(); + } + + /// Returns the id of the represented e-class with the greatest id. + #[inline] + pub fn greatest_represented(&self) -> Option<Id> { + if let Some(id) = self.inner.len().checked_sub(1) { + return Some(Self::index_to_id(id)); + } + None + } + + /// Indicates if no e-classes are represented. + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Reserves sufficient space in the union-find data structure to contain + /// all of the given e-class ids. + /// + /// # Errors + /// + /// An error is returned if reserving space for new ids fails. + #[inline(always)] + pub fn ensure_contains_all( + &mut self, + ids: impl IntoIterator<Item = Id>, + ) -> Result<(), TryReserveError> { + if let Some(id) = ids.into_iter().max() { + self.ensure_contains(id)?; + } + Ok(()) + } + + /// Reserves sufficient space in the union-find data structure to contain + /// the given e-class id. + /// + /// # Errors + /// + /// An error is returned if reserving space for new ids fails. + #[inline] + pub fn ensure_contains(&mut self, id: Id) -> Result<(), TryReserveError> { + let index = id.into_usize(); + + if index < self.inner.len() { + return Ok(()); + } + + let current_size = self.inner.len(); + + //NOTE: we just established that index is at least equal to the + // length of the parent vector and `saturating_add` will prevent + // roll-over to 0 from `usize::MAX` + self.inner.try_reserve( + index.wrapping_sub(current_size).saturating_add(1), + )?; + + for (parent, child) in self + .inner + .spare_capacity_mut() + .iter_mut() + .zip(current_size..=index) + { + parent.write(Self::index_to_id(child)); + } + + //NOTE: we know by the requirement that vectors (containing non-ZSTs) + // may have at most a size of `isize::MAX` that our index is + // nowhere near `usize::MAX` + //SAFETY: we just wrote to the unused capacity + unsafe { + self.inner.set_len(index.wrapping_add(1)); + } + + Ok(()) + } + + /// Locates the canonical representation of the given e-class without + /// performing path compression. + #[inline] + pub fn find(&self, id: Id) -> Id { + if id.into_usize() >= self.inner.len() { + return id; + } + + //NOTE: we know all further indices are valid from this point on + + let mut current = id; + loop { + //SAFETY: refer to the above note + let parent = + *unsafe { self.inner.get_unchecked(current.into_usize()) }; + if current == parent { + break current; + } + current = parent; + } + } + + /// Indicates if two e-class ids share the same parent. + #[inline(always)] + pub fn equivalent(&self, a: Id, b: Id) -> bool { + self.find(a) == self.find(b) + } + + /// Indicates if two e-class ids share the same parent, performing path + /// compression. + #[inline(always)] + pub fn equivalent_mut(&mut self, a: Id, b: Id) -> bool { + self.find_mut(a) == self.find_mut(b) + } + + /// Locates the canonical representation of the given e-class, performing + /// path compression. + #[inline] + pub fn find_mut(&mut self, id: Id) -> Id { + if id.into_usize() >= self.inner.len() { + return id; + } + + let mut current = id; + let base_ptr = self.inner.as_mut_ptr(); + + //SAFETY: we know these indices are in bounds after the check above + // and due to invariants on the contents of the vector + unsafe { + loop { + let parent_ptr = base_ptr.add(current.into_usize()); + let parent = *parent_ptr; + + if current == parent { + break current; + } + + let grandparent = *base_ptr.add(parent.into_usize()); + parent_ptr.write(grandparent); + current = grandparent; + } + } + } + + /* TODO: implement this and the data structure it will use + /// Locates the canonical representation of several e-classes, performing path compression. + #[inline] + pub fn find_mut_all(&mut self, ids: impl IntoIterator<Item = Id>) -> Result<FindMutAll<Id>, TryReserveError> { + //NOTE: needs a try_extend or something to fallibly collect the iterator into a new vec and sort it. we can do this but it'd be more work in the utilities file + } + */ + + /// Merges the given e-classes and returns their representatives, + /// indicating which is the parent of the other. + /// + /// If either is not present and they differ, space is reserved for them + /// both. + /// + /// # Errors + /// + /// An error is returned if reserving space for new ids fails. + #[inline] + pub fn merge( + &mut self, + a: Id, + b: Id, + ) -> Result<Union<Id>, TryReserveError> { + let a = self.find_mut(a); + let b = self.find_mut(b); + + let [parent, child] = cmp::minmax(a, b); + + if parent != child { + self.ensure_contains(child)?; + + //SAFETY: because of invariants on the inner vector, and due to + // us reserving space if the id is not already contained, + // this is safe + *unsafe { self.inner.get_unchecked_mut(child.into_usize()) } = + parent; + } + + Ok(Union { parent, child }) + } +} + +impl<Id> Default for UnionFind<Id, Global> { + fn default() -> Self { + Self::new() + } +} + +/// Result of merging two e-classes. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct Union<Id> { + /// The parent e-class. + pub parent: Id, + + /// The child e-class. + pub child: Id, +} + +impl<Id> Union<Id> +where + Id: Identifier, +{ + /// Indicates if an e-class merge was made in this union operation. + #[must_use] + #[inline(always)] + pub fn changed(&self) -> bool { + self.parent != self.child + } +} + +/// Iterator over canonical representations of e-classes. +pub struct FindMutAll<'a, Id, A> +where + A: Allocator + Clone, +{ + /// The union find we are locating canonical representations within. + pub(crate) union_find: &'a mut UnionFind<Id, A>, + + /// The `Id`s we are finding canonical representations of. + pub(crate) ids: Vec<Id, A>, +} + +#[cfg(all(test, feature = "core-error"))] +mod test { + use super::*; + + #[test] + fn union_find_reserve() { + let mut union_find = UnionFind::default(); + + union_find + .ensure_contains(5usize) + .expect("unable to reserve space to contain an identifier"); + + assert_eq!(&[0, 1, 2, 3, 4, 5], union_find.inner.as_slice()); + + union_find + .ensure_contains(6usize) + .expect("unable to reserve space to contain an identifier"); + + assert_eq!(&[0, 1, 2, 3, 4, 5, 6], union_find.inner.as_slice()); + + union_find + .ensure_contains(5usize) + .expect("unable to reserve space to contain an identifier"); + + assert_eq!(&[0, 1, 2, 3, 4, 5, 6], union_find.inner.as_slice()); + + union_find + .ensure_contains(0usize) + .expect("unable to reserve space to contain an identifier"); + + assert_eq!(&[0, 1, 2, 3, 4, 5, 6], union_find.inner.as_slice()); + + union_find + .ensure_contains(usize::MAX) + .expect_err("reserving greater than `isize::MAX` should fail"); + } +} |
