diff options
| author | superwhiskers <[email protected]> | 2025-09-15 13:38:14 -0500 |
|---|---|---|
| committer | superwhiskers <[email protected]> | 2026-01-04 22:23:01 -0600 |
| commit | e12b1f4459aee80ee333e90e3b56a3b09f81ae3e (patch) | |
| tree | 872402360b490c992bb0d7e071ab2834adeae03e /crates/core/src/expressions/mod.rs | |
| parent | 50192cbe91da765d3c09cf8e12f328b721d3cb46 (diff) | |
| download | azimuth-e12b1f4459aee80ee333e90e3b56a3b09f81ae3e.tar.gz azimuth-e12b1f4459aee80ee333e90e3b56a3b09f81ae3e.tar.bz2 azimuth-e12b1f4459aee80ee333e90e3b56a3b09f81ae3e.zip | |
node topological sorting
Change-Id: I6a6a6964255d818be1bf9a8f4ec9e317befa19c5
Diffstat (limited to 'crates/core/src/expressions/mod.rs')
| -rw-r--r-- | crates/core/src/expressions/mod.rs | 429 |
1 files changed, 360 insertions, 69 deletions
diff --git a/crates/core/src/expressions/mod.rs b/crates/core/src/expressions/mod.rs index 902ead1..5ffb5a6 100644 --- a/crates/core/src/expressions/mod.rs +++ b/crates/core/src/expressions/mod.rs @@ -7,10 +7,16 @@ use alloc::{ vec, vec::Vec, }; -use core::{hint, iter::TrustedLen}; +use core::{hint, iter::TrustedLen, mem::MaybeUninit}; use generativity::{Guard, Id}; -use crate::numerics::rational::Rational; +use crate::{ + numerics::rational::Rational, + utilities::{ + CheckedArithmeticExt, IntegerOverflowError, VecMemoryExt, + WrappingArithmeticExt, + }, +}; #[cfg(feature = "core-fmt")] use core::fmt; @@ -18,15 +24,11 @@ use core::fmt; #[cfg(feature = "core-error")] use core::error; -//TODO: consider topological sorting of the expression's internals as this -// may be better for performance. -//TODO: consider garbage collection of unreachable nodes (and their -// children). -//TODO: if we add garbage collection, support marking a node for deletion -// when garbage collection is performed. -//TODO: consider renaming `OpenExpression` to something else. -//TODO: add more thorough documentation, along the lines of the standard -// library's `Vec`, given how this is a very fundamental type. +//TODO: support marking a node for deletion when garbage collection is +// performed. +//TODO: consider btree/hashmap for interned strings to speed up the process +// of locating the intern. +//TODO: make topological sorting in-place. /// Representation of a mathematical expression. /// @@ -54,10 +56,14 @@ where /// Root node of the expression. pub(crate) root: Option<NodeIndexInternal>, + + /// Allocator used by the [`Expression`]. + pub(crate) alloc: A, } impl Expression<Global> { /// Constructs an empty expression. + #[must_use] pub fn new() -> Self { Self::new_in(Global) } @@ -68,21 +74,80 @@ where A: Allocator + Clone, { /// Constructs an empty expression using the provided [`Allocator`]. + #[must_use] pub fn new_in(alloc: A) -> Self { Self { inner: Vec::new_in(alloc.clone()), - interns: Vec::new_in(alloc), + interns: Vec::new_in(alloc.clone()), root: None, + alloc, } } /// Clears the expression, removing all nodes and associated information. + /// + /// This does not free up unused space. To do that, use + /// [`Expression::shrink_to_fit`] after calling this. pub fn clear(&mut self) { self.inner.clear(); self.interns.clear(); self.root = None; } + /// Sorts and garbage collects the node list. + /// + /// This may not necessarily free the unused memory. To ensure that + /// happens, call [`Expression::shrink_to_fit`] after calling this. + /// + /// # Errors + /// + /// An error is returned if reserving space for the permutation vector + /// failed or if a cycle was found. + /// + /// The latter should not happen, but is enumerated in possible errors for + /// testing purposes. Prefer crashing if this happens. + pub fn sort(mut self) -> Result<Self, Error> { + if let Some(root) = self.root { + let (permutation, n_reachable) = permutation_vector_of( + self.alloc.clone(), + root, + self.inner.as_slice(), + )?; + + let mut inner = + Vec::with_capacity_in(n_reachable, self.alloc.clone()); + + let start = inner.as_mut_ptr() as *mut MaybeUninit<NodeInternal>; + for (node, position) in self.inner.into_iter().zip(permutation) { + if position != usize::MAX { + //SAFETY: positions are guaranteed to be less than + // `n_reachable` as `n_reachable` would have been + // the next position, all indices are unique + unsafe { + start.add(position).write(MaybeUninit::new(node)); + } + } + } + + //SAFETY: we sized the vector to contain exactly the number of + // reachable nodes + unsafe { inner.set_len(n_reachable) }; + + self.inner = inner; + } else { + self.clear(); + } + + Ok(self) + } + + /// Shrinks the capacity of the node list and intern list as much as + /// possible. + pub fn shrink_to_fit(&mut self) { + self.inner.shrink_to_fit(); + self.interns.shrink_to_fit(); + } + /// Opens up the expression for additive operations and indexing. /// /// While in this state, no destructive operations may be done, such as @@ -101,10 +166,9 @@ where &mut self, node: NodeInternal, ) -> Result<NodeIndexInternal, Error> { - self.inner.try_reserve(1)?; - //SAFETY: we just reserved space for this element - unsafe { self.inner.push_within_capacity(node).unwrap_unchecked() } - Ok(NodeIndexInternal(self.inner.len() - 1)) + let index = self.inner.len(); + self.inner.try_push(node)?; + Ok(NodeIndexInternal(index)) } /// Inserts a [`&str`] into the intern list if it doesn't already exist @@ -124,14 +188,9 @@ where return Ok(StringIndexInternal(i)); } - self.interns.try_reserve(1)?; - //SAFETY: we just reserved space for this element - unsafe { - self.interns - .push_within_capacity(string.as_ref().into()) - .unwrap_unchecked() - } - Ok(StringIndexInternal(self.interns.len() - 1)) + let index = self.interns.len(); + self.interns.try_push(string.as_ref().into())?; + Ok(StringIndexInternal(index)) } /// Returns the number of nodes within this expression. @@ -175,6 +234,9 @@ where while let Some((node_index, state)) = stack.pop() { //SAFETY: all internal indices are guaranteed to be valid + //NOTE: this lint is ignored because we have a state machine + // test which exercises this wildcard arm + #[allow(clippy::wildcard_enum_match_arm)] match *unsafe { self.inner.get_unchecked(node_index.0) } { NodeInternal::BinaryOperation(_, _, _) if state == State::Close => @@ -243,7 +305,11 @@ where self.interns.get_unchecked(s) })?; } - _ => unreachable!(), + _ => + { + #![allow(clippy::unreachable)] + unreachable!() + } } } } @@ -347,8 +413,15 @@ where ) .into()) } else { + //NOTE: used to roll back the vector if an invalid child node is + // found in the arguments + let rollback_size = self.0.inner.len(); + let number_of_joins = n_arguments.saturating_sub(1); - self.0.inner.try_reserve(number_of_joins + 1)?; + self.0.inner.try_reserve(number_of_joins.errored_add(1)?)?; + + //NOTE: we define this beforehand to avoid arithmetic + let operation_index = NodeIndexInternal(self.0.inner.len()); //SAFETY: the iterator is required to implement `TrustedLen` so // we will have reserved enough space for these in addition @@ -381,30 +454,39 @@ where // more slot to hold an `NodeInternal::Join` _ => NodeInternal::Operation( operator, - Some(NodeIndexInternal(self.0.inner.len() + 1)), + //NOTE: `Vec::len` is bounded by `isize::MAX` + // for non-ZSTs + Some(NodeIndexInternal( + self.0.inner.len().wrapping_add(1), + )), ), }) - .unwrap_unchecked() + .unwrap_unchecked(); } - let operation_index = NodeIndexInternal(self.0.inner.len() - 1); + + //NOTE: we know by this point that the number of arguments is at + // least 3 for the code that uses it + let final_index = n_arguments.wrapping_sub(1); //NOTE: we are exclusively inserting `NodeInternal::Join`s now for (i, argument) in arguments.enumerate() { + //NOTE: this check is necessary because it is possible for a + // consumer to get a lifetime-branded index to a `Join` + // node if matches!( //SAFETY: lifetime branding ensures this index is valid unsafe { self.0.inner.get_unchecked(argument.0) }, NodeInternal::Join(_, _) ) { - //NOTE: leaving the inner vector in this state is not - // invalid as nothing will be able to reference any - // of the slots allocated already - // - // when it is added, garbage collection will clean up - // after this if it happens + //NOTE: this is necessary to avoid any invalid indices + // being present in the vector when we perform + // topological sort + self.0.inner.truncate(rollback_size); + return Err(Error::InvalidChild); } - if i != n_arguments - 1 { + if i != final_index { //SAFETY: we know that we have reserved enough space for // each join by this point unsafe { @@ -412,18 +494,23 @@ where .inner .push_within_capacity(NodeInternal::Join( argument.into(), - NodeIndexInternal(self.0.inner.len() + 1), + //NOTE: `Vec::len` is bounded by + // `isize::MAX` for non-ZSTs + NodeIndexInternal( + self.0.inner.len().wrapping_add(1), + ), )) - .unwrap_unchecked() + .unwrap_unchecked(); } } else { - let last_index = self.0.inner.len() - 1; + //NOTE: this is okay because we added the previous index + let previous_index = self.0.inner.len().wrapping_sub(1); if let &mut NodeInternal::Join(_, ref mut second) = //SAFETY: we know that if we are at this point, // the current last element exists unsafe { - self.0.inner.get_unchecked_mut(last_index) + self.0.inner.get_unchecked_mut(previous_index) } { *second = argument.into(); @@ -451,12 +538,12 @@ where /// Returns the number of nodes within this expression. pub fn len(&self) -> usize { - self.0.inner.len() + self.0.len() } /// Indicates if the expression has no nodes. pub fn is_empty(&self) -> bool { - self.0.inner.is_empty() + self.0.is_empty() } /// Gets the root node of the expression, if it exists. @@ -565,23 +652,12 @@ impl<'a, 'id> Iterator for Children<'a, 'id> { let current = unsafe { self.inner.get_unchecked(current.0) }; Some( ( + //NOTE: this lint is ignored because we have a state + // machine test which exercises this wildcard arm + #[allow(clippy::wildcard_enum_match_arm)] match *current { - NodeInternal::Operation(_, Some(index)) => { - //SAFETY: all internal indices are guaranteed to - // be valid - let next = - unsafe { self.inner.get_unchecked(index.0) }; - - if let NodeInternal::Join(child_index, _) = *next - { - self.current = Some(index); - child_index - } else { - self.current = None; - index - } - } - NodeInternal::Join(_, index) => { + NodeInternal::Join(_, index) + | NodeInternal::Operation(_, Some(index)) => { //SAFETY: all internal indices are guaranteed to // be valid let next = @@ -606,7 +682,11 @@ impl<'a, 'id> Iterator for Children<'a, 'id> { self.current = None; index } - _ => unreachable!(), + _ => + { + #![allow(clippy::unreachable)] + unreachable!() + } }, self.id, ) @@ -623,7 +703,7 @@ impl<'a, 'id> Iterator for Children<'a, 'id> { #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 -#[allow(repr_transparent_external_private_fields)] +#[allow(repr_transparent_non_zst_fields)] pub struct NodeIndex<'id>(usize, Id<'id>); impl<'id> From<(NodeIndexInternal, Id<'id>)> for NodeIndex<'id> { @@ -652,7 +732,7 @@ impl From<NodeIndex<'_>> for NodeIndexInternal { #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 -#[allow(repr_transparent_external_private_fields)] +#[allow(repr_transparent_non_zst_fields)] pub struct StringIndex<'id>(usize, Id<'id>); impl<'id> From<(StringIndexInternal, Id<'id>)> for StringIndex<'id> { @@ -798,19 +878,189 @@ impl fmt::Display for ScalarValue { } } +/// Returns a permutation vector for the given list of [`Expression`] nodes +/// and the number of reachable nodes. +/// +/// Applying this permutation vector to the node list will topologically sort +/// it. Nodes without any incoming edges will be marked with `usize::MAX` in +/// the permutation vector. The root node is ignored as it is not expected to +/// have an incoming edge. This is acceptable as given our constraints, it is +/// not possible for an index to be greater than `isize::MAX`. +/// +/// # Errors +/// +/// An error is returned if a cycle was found or if reserving space for the +/// permutation vector failed. +fn permutation_vector_of<A>( + alloc: A, + root: NodeIndexInternal, + vertices: &[NodeInternal], +) -> Result<(Vec<usize, A>, usize), Error> +where + A: Allocator + Clone, +{ + let mut indegree = + Vec::try_with_capacity_in(vertices.len(), alloc.clone())?; + indegree.resize(vertices.len(), 0); + + //NOTE: used to queue nodes for reachability checks and Kahn's algorithm + let mut worklist = Vec::try_with_capacity_in(1, alloc.clone())?; + + //SAFETY: we just reserved space for this element + unsafe { + worklist.push_within_capacity(root.0).unwrap_unchecked(); + } + + //NOTE: maps indices to their index in a topologically sorted node + // vector. indices with the value of `usize::MAX` are not + // reachable and should be removed by garbage collection. during the + // reachability check, nodes which are reachable are mapped to + // themselves as this is harmless and can be used to skip seen nodes + let mut permutation = Vec::try_with_capacity_in(vertices.len(), alloc)?; + permutation.resize(vertices.len(), usize::MAX); + + let mut n_reachable = 0usize; + while let Some(current) = worklist.pop() { + //SAFETY: we have reserved enough space to track the permuted index + // of all valid indices + let permuted_index = + unsafe { permutation.get_unchecked_mut(current) }; + if *permuted_index == current { + continue; + } + *permuted_index = current; + + //NOTE: this is okay because allocation sizes are limited to + // `isize::MAX`; + n_reachable.wrapping_increment_mut(); + + //SAFETY: all indices in the vertex list are valid + #[allow(clippy::wildcard_enum_match_arm)] + match unsafe { vertices.get_unchecked(current) } { + NodeInternal::Operation(_, Some(NodeIndexInternal(i))) => { + //SAFETY: we have reserved enough space to track the indegree + // of all valid indices + unsafe { indegree.get_unchecked_mut(*i) } + .wrapping_increment_mut(); + + worklist.try_push(*i)?; + } + NodeInternal::Join( + NodeIndexInternal(i), + NodeIndexInternal(j), + ) + | NodeInternal::BinaryOperation( + _, + NodeIndexInternal(i), + NodeIndexInternal(j), + ) => { + //SAFETY: ditto + unsafe { indegree.get_unchecked_mut(*i) } + .wrapping_increment_mut(); + + worklist.try_push(*i)?; + + //SAFETY: ditto + unsafe { indegree.get_unchecked_mut(*j) } + .wrapping_increment_mut(); + + worklist.try_push(*j)?; + } + _ => (), + } + } + + //NOTE: tracks the location of the cursor within the permuted vector and + // serves as a count of the nodes which we have permuted + let mut permuted_position = 0; + + //NOTE: double check there's no cycle that goes to root. + //SAFETY: we have reserved enough space to track the indegree of all + // valid indices + if *unsafe { indegree.get_unchecked(root.0) } != 0 { + return Err(Error::Cycle); + } + + worklist.try_push(root.0)?; + while let Some(current) = worklist.pop() { + //SAFETY: we constructed the permutation vector so that indexing into + // it with vertex indices is always okay + *unsafe { permutation.get_unchecked_mut(current) } = + permuted_position; + permuted_position.wrapping_increment_mut(); + + //SAFETY: all indices in the vertex list are valid + #[allow(clippy::wildcard_enum_match_arm)] + match unsafe { vertices.get_unchecked(current) } { + NodeInternal::Operation(_, Some(NodeIndexInternal(i))) => { + //SAFETY: we have reserved enough space to track the indegree + // of all valid indices + let indegree_of_i = unsafe { indegree.get_unchecked_mut(*i) }; + + indegree_of_i.wrapping_decrement_mut(); + if *indegree_of_i == 0 { + worklist.try_push(*i)?; + } + } + NodeInternal::Join( + NodeIndexInternal(i), + NodeIndexInternal(j), + ) + | NodeInternal::BinaryOperation( + _, + NodeIndexInternal(i), + NodeIndexInternal(j), + ) => { + //SAFETY: ditto + let indegree_of_i = unsafe { indegree.get_unchecked_mut(*i) }; + + indegree_of_i.wrapping_decrement_mut(); + if *indegree_of_i == 0 { + worklist.try_push(*i)?; + } + + //SAFETY: ditto + let indegree_of_j = unsafe { indegree.get_unchecked_mut(*j) }; + + indegree_of_j.wrapping_decrement_mut(); + if *indegree_of_j == 0 { + worklist.try_push(*j)?; + } + } + _ => (), + } + } + + if n_reachable == permuted_position { + Ok((permutation, n_reachable)) + } else { + Err(Error::Cycle) + } +} + /// Representation of an error that occurred within [`Expression`]. #[non_exhaustive] -#[derive(Eq, PartialEq)] +#[derive(PartialEq, Eq)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub enum Error { /// Memory reservation error. TryReserveError(TryReserveError), + /// Unhandleable integer overflow error. + IntegerOverflowError, + /// Invalid child node. /// /// This occurs when you attempt to pass an index that points to a /// [`Node::Join`] as a child of an operation. InvalidChild, + + /// A cycle was found in the [`Expression`]. + /// + /// This is impossible, but representable as an error state for testing + /// purposes. + #[doc(hidden)] + Cycle, } impl From<TryReserveError> for Error { @@ -819,12 +1069,20 @@ impl From<TryReserveError> for Error { } } +impl From<IntegerOverflowError> for Error { + fn from(_: IntegerOverflowError) -> Self { + Self::IntegerOverflowError + } +} + #[cfg(feature = "core-fmt")] impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let reason = match self { Error::TryReserveError(_) => "unable to allocate memory", + Error::IntegerOverflowError => "integer overflow", Error::InvalidChild => "invalid child node", + Error::Cycle => "expression contains a cycle", }; f.write_str(reason) } @@ -888,9 +1146,16 @@ mod test { ); expr.set_root_node(product); - let expr = expr.close(); - println!("{}", expr); + let old_length = expr.len(); + let expr = expr.close().sort().expect("unable to sort an expression"); + assert_eq!(old_length, expr.len()); + + make_guard!(h); + let mut expr = expr.open(h); + expr.take_root_node(); + let expr = expr.close().sort().expect("unable to sort an expression"); + assert!(expr.is_empty()); } /// Reference state machine for [`Expression`]s. @@ -1171,10 +1436,17 @@ mod test { type Reference = ExpressionReference; fn init_test( - _ref_state: &<Self::Reference as ReferenceStateMachine>::State, + ref_state: &<Self::Reference as ReferenceStateMachine>::State, ) -> Self::SystemUnderTest { - //TODO: randomly initialize state in the reference and replicate - // it over here. + /*let mut expr = Expression { + inner: Vec::new(), + interns: Vec::new(), + root: ref_state.root.map(NodeIndexInternal), + }; + + for node in ref_state.nodes { + //TODO: finish replicating state over here + }*/ ExpressionWrapper::Closed(Expression::new()) } @@ -1454,7 +1726,14 @@ mod test { } } - //TODO: the graph must never contain a cycle. + //NOTE: ensure string conversion doesn't panic. + let _ = expr.to_string(); + + if let Some(root) = expr.root { + permutation_vector_of(Global, root, expr.inner.as_slice()) + .expect("expression must not contain a cycle"); + } + //TODO: validate the layout of the entire expression against the // reference. } @@ -1475,6 +1754,18 @@ mod test { })] #[test] - fn matches_state_machine(sequential 1..1_000 => ExpressionWrapper); + fn matches_state_machine(sequential 1..=250 => ExpressionWrapper); + } +} + +#[cfg(all(kani, feature = "core-error"))] +mod verification { + use super::*; + + use generativity::make_guard; + + #[kani::proof] + fn simple() { + //TODO: add more of a proof here } } |
