about summary refs log tree commit diff stats
path: root/crates/core/src/expressions/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/core/src/expressions/mod.rs')
-rw-r--r-- crates/core/src/expressions/mod.rs 429
1 files changed, 360 insertions, 69 deletions
diff --git a/crates/core/src/expressions/mod.rs b/crates/core/src/expressions/mod.rs
index 902ead1..5ffb5a6 100644
--- a/crates/core/src/expressions/mod.rs
+++ b/crates/core/src/expressions/mod.rs
@@ -7,10 +7,16 @@ use alloc::{
     vec,
     vec::Vec,
 };
-use core::{hint, iter::TrustedLen};
+use core::{hint, iter::TrustedLen, mem::MaybeUninit};
 use generativity::{Guard, Id};
 
-use crate::numerics::rational::Rational;
+use crate::{
+    numerics::rational::Rational,
+    utilities::{
+        CheckedArithmeticExt, IntegerOverflowError, VecMemoryExt,
+        WrappingArithmeticExt,
+    },
+};
 
 #[cfg(feature = "core-fmt")]
 use core::fmt;
@@ -18,15 +24,11 @@ use core::fmt;
 #[cfg(feature = "core-error")]
 use core::error;
 
-//TODO: consider topological sorting of the expression's internals as this
-//      may be better for performance.
-//TODO: consider garbage collection of unreachable nodes (and their
-//      children).
-//TODO: if we add garbage collection, support marking a node for deletion
-//      when garbage collection is performed.
-//TODO: consider renaming `OpenExpression` to something else.
-//TODO: add more thorough documentation, along the lines of the standard
-//      library's `Vec`, given how this is a very fundamental type.
+//TODO: support marking a node for deletion when garbage collection is
+//      performed.
+//TODO: consider btree/hashmap for interned strings to speed up the process
+//      of locating the intern.
+//TODO: make topological sorting in-place.
 
 /// Representation of a mathematical expression.
 ///
@@ -54,10 +56,14 @@ where
 
     /// Root node of the expression.
     pub(crate) root: Option<NodeIndexInternal>,
+
+    /// Allocator used by the [`Expression`].
+    pub(crate) alloc: A,
 }
 
 impl Expression<Global> {
     /// Constructs an empty expression.
+    #[must_use]
     pub fn new() -> Self {
         Self::new_in(Global)
     }
@@ -68,21 +74,80 @@ where
     A: Allocator + Clone,
 {
     /// Constructs an empty expression using the provided [`Allocator`].
+    #[must_use]
     pub fn new_in(alloc: A) -> Self {
         Self {
             inner: Vec::new_in(alloc.clone()),
-            interns: Vec::new_in(alloc),
+            interns: Vec::new_in(alloc.clone()),
             root: None,
+            alloc,
         }
     }
 
     /// Clears the expression, removing all nodes and associated information.
+    ///
+    /// This does not free up unused space. To do that, use
+    /// [`Expression::shrink_to_fit`] after calling this.
     pub fn clear(&mut self) {
         self.inner.clear();
         self.interns.clear();
         self.root = None;
     }
 
+    /// Sorts and garbage collects the node list.
+    ///
+    /// Child and root indices are rewritten to match the new layout. To free
+    /// unused memory, call [`Expression::shrink_to_fit`] afterwards.
+    ///
+    /// # Errors
+    ///
+    /// An error is returned if reserving space failed or if a cycle was found.
+    /// The latter should not happen (it exists for testing); prefer crashing.
+    pub fn sort(mut self) -> Result<Self, Error> {
+        let Some(root) = self.root else {
+            self.clear();
+            return Ok(self);
+        };
+        let (map, n_reachable) =
+            permutation_vector_of(self.alloc.clone(), root, &self.inner)?;
+        let mut inner =
+            Vec::try_with_capacity_in(n_reachable, self.alloc.clone())?;
+        let start = inner.as_mut_ptr() as *mut MaybeUninit<NodeInternal>;
+        let live = self.inner.into_iter().zip(&map);
+        for (mut node, &position) in live.filter(|&(_, &p)| p != usize::MAX) {
+            //SAFETY: children of reachable nodes are themselves reachable,
+            //        so `map` holds a valid new position for each of them
+            #[allow(clippy::wildcard_enum_match_arm)]
+            match &mut node {
+                NodeInternal::Operation(_, Some(i)) => unsafe {
+                    i.0 = *map.get_unchecked(i.0);
+                },
+                NodeInternal::Join(i, j)
+                | NodeInternal::BinaryOperation(_, i, j) => unsafe {
+                    i.0 = *map.get_unchecked(i.0);
+                    j.0 = *map.get_unchecked(j.0);
+                },
+                _ => (),
+            }
+            //SAFETY: positions are unique and less than `n_reachable`
+            unsafe { start.add(position).write(MaybeUninit::new(node)) };
+        }
+        //SAFETY: every one of the `n_reachable` slots was written once
+        unsafe { inner.set_len(n_reachable) };
+        //SAFETY: `root` is a valid index into the permutation vector
+        self.root =
+            Some(NodeIndexInternal(*unsafe { map.get_unchecked(root.0) }));
+        self.inner = inner;
+        Ok(self)
+    }
+
+    /// Shrinks the capacity of the node list and intern list as much as
+    /// possible.
+    pub fn shrink_to_fit(&mut self) {
+        self.inner.shrink_to_fit();
+        self.interns.shrink_to_fit();
+    }
+
     /// Opens up the expression for additive operations and indexing.
     ///
     /// While in this state, no destructive operations may be done, such as
@@ -101,10 +166,9 @@ where
         &mut self,
         node: NodeInternal,
     ) -> Result<NodeIndexInternal, Error> {
-        self.inner.try_reserve(1)?;
-        //SAFETY: we just reserved space for this element
-        unsafe { self.inner.push_within_capacity(node).unwrap_unchecked() }
-        Ok(NodeIndexInternal(self.inner.len() - 1))
+        let index = self.inner.len();
+        self.inner.try_push(node)?;
+        Ok(NodeIndexInternal(index))
     }
 
     /// Inserts a [`&str`] into the intern list if it doesn't already exist
@@ -124,14 +188,9 @@ where
             return Ok(StringIndexInternal(i));
         }
 
-        self.interns.try_reserve(1)?;
-        //SAFETY: we just reserved space for this element
-        unsafe {
-            self.interns
-                .push_within_capacity(string.as_ref().into())
-                .unwrap_unchecked()
-        }
-        Ok(StringIndexInternal(self.interns.len() - 1))
+        let index = self.interns.len();
+        self.interns.try_push(string.as_ref().into())?;
+        Ok(StringIndexInternal(index))
     }
 
     /// Returns the number of nodes within this expression.
@@ -175,6 +234,9 @@ where
 
             while let Some((node_index, state)) = stack.pop() {
                 //SAFETY: all internal indices are guaranteed to be valid
+                //NOTE: this lint is ignored because we have a state machine
+                //      test which exercises this wildcard arm
+                #[allow(clippy::wildcard_enum_match_arm)]
                 match *unsafe { self.inner.get_unchecked(node_index.0) } {
                     NodeInternal::BinaryOperation(_, _, _)
                         if state == State::Close =>
@@ -243,7 +305,11 @@ where
                             self.interns.get_unchecked(s)
                         })?;
                     }
-                    _ => unreachable!(),
+                    _ =>
+                    {
+                        #![allow(clippy::unreachable)]
+                        unreachable!()
+                    }
                 }
             }
         }
@@ -347,8 +413,15 @@ where
             )
                 .into())
         } else {
+            //NOTE: used to roll back the vector if an invalid child node is
+            //      found in the arguments
+            let rollback_size = self.0.inner.len();
+
             let number_of_joins = n_arguments.saturating_sub(1);
-            self.0.inner.try_reserve(number_of_joins + 1)?;
+            self.0.inner.try_reserve(number_of_joins.errored_add(1)?)?;
+
+            //NOTE: we define this beforehand to avoid arithmetic
+            let operation_index = NodeIndexInternal(self.0.inner.len());
 
             //SAFETY: the iterator is required to implement `TrustedLen` so
             //        we will have reserved enough space for these in addition
@@ -381,30 +454,39 @@ where
                         //      more slot to hold an `NodeInternal::Join`
                         _ => NodeInternal::Operation(
                             operator,
-                            Some(NodeIndexInternal(self.0.inner.len() + 1)),
+                            //NOTE: `Vec::len` is bounded by `isize::MAX`
+                            //      for non-ZSTs
+                            Some(NodeIndexInternal(
+                                self.0.inner.len().wrapping_add(1),
+                            )),
                         ),
                     })
-                    .unwrap_unchecked()
+                    .unwrap_unchecked();
             }
-            let operation_index = NodeIndexInternal(self.0.inner.len() - 1);
+
+            //NOTE: we know by this point that the number of arguments is at
+            //      least 3 for the code that uses it
+            let final_index = n_arguments.wrapping_sub(1);
 
             //NOTE: we are exclusively inserting `NodeInternal::Join`s now
             for (i, argument) in arguments.enumerate() {
+                //NOTE: this check is necessary because it is possible for a
+                //      consumer to get a lifetime-branded index to a `Join`
+                //      node
                 if matches!(
                     //SAFETY: lifetime branding ensures this index is valid
                     unsafe { self.0.inner.get_unchecked(argument.0) },
                     NodeInternal::Join(_, _)
                 ) {
-                    //NOTE: leaving the inner vector in this state is not
-                    //      invalid as nothing will be able to reference any
-                    //      of the slots allocated already
-                    //
-                    //      when it is added, garbage collection will clean up
-                    //      after this if it happens
+                    //NOTE: this is necessary to avoid any invalid indices
+                    //      being present in the vector when we perform
+                    //      topological sort
+                    self.0.inner.truncate(rollback_size);
+
                     return Err(Error::InvalidChild);
                 }
 
-                if i != n_arguments - 1 {
+                if i != final_index {
                     //SAFETY: we know that we have reserved enough space for
                     //        each join by this point
                     unsafe {
@@ -412,18 +494,23 @@ where
                             .inner
                             .push_within_capacity(NodeInternal::Join(
                                 argument.into(),
-                                NodeIndexInternal(self.0.inner.len() + 1),
+                                //NOTE: `Vec::len` is bounded by
+                                //      `isize::MAX` for non-ZSTs
+                                NodeIndexInternal(
+                                    self.0.inner.len().wrapping_add(1),
+                                ),
                             ))
-                            .unwrap_unchecked()
+                            .unwrap_unchecked();
                     }
                 } else {
-                    let last_index = self.0.inner.len() - 1;
+                    //NOTE: this is okay because we added the previous index
+                    let previous_index = self.0.inner.len().wrapping_sub(1);
 
                     if let &mut NodeInternal::Join(_, ref mut second) =
                         //SAFETY: we know that if we are at this point,
                         //        the current last element exists
                         unsafe {
-                            self.0.inner.get_unchecked_mut(last_index)
+                            self.0.inner.get_unchecked_mut(previous_index)
                         }
                     {
                         *second = argument.into();
@@ -451,12 +538,12 @@ where
 
     /// Returns the number of nodes within this expression.
     pub fn len(&self) -> usize {
-        self.0.inner.len()
+        self.0.len()
     }
 
     /// Indicates if the expression has no nodes.
     pub fn is_empty(&self) -> bool {
-        self.0.inner.is_empty()
+        self.0.is_empty()
     }
 
     /// Gets the root node of the expression, if it exists.
@@ -565,23 +652,12 @@ impl<'a, 'id> Iterator for Children<'a, 'id> {
             let current = unsafe { self.inner.get_unchecked(current.0) };
             Some(
                 (
+                    //NOTE: this lint is ignored because we have a state
+                    //      machine test which exercises this wildcard arm
+                    #[allow(clippy::wildcard_enum_match_arm)]
                     match *current {
-                        NodeInternal::Operation(_, Some(index)) => {
-                            //SAFETY: all internal indices are guaranteed to
-                            //        be valid
-                            let next =
-                                unsafe { self.inner.get_unchecked(index.0) };
-
-                            if let NodeInternal::Join(child_index, _) = *next
-                            {
-                                self.current = Some(index);
-                                child_index
-                            } else {
-                                self.current = None;
-                                index
-                            }
-                        }
-                        NodeInternal::Join(_, index) => {
+                        NodeInternal::Join(_, index)
+                        | NodeInternal::Operation(_, Some(index)) => {
                             //SAFETY: all internal indices are guaranteed to
                             //        be valid
                             let next =
@@ -606,7 +682,11 @@ impl<'a, 'id> Iterator for Children<'a, 'id> {
                             self.current = None;
                             index
                         }
-                        _ => unreachable!(),
+                        _ =>
+                        {
+                            #![allow(clippy::unreachable)]
+                            unreachable!()
+                        }
                     },
                     self.id,
                 )
@@ -623,7 +703,7 @@ impl<'a, 'id> Iterator for Children<'a, 'id> {
 #[derive(Eq, PartialEq, Copy, Clone, Hash)]
 #[cfg_attr(feature = "core-fmt", derive(Debug))]
 //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13
-#[allow(repr_transparent_external_private_fields)]
+#[allow(repr_transparent_non_zst_fields)]
 pub struct NodeIndex<'id>(usize, Id<'id>);
 
 impl<'id> From<(NodeIndexInternal, Id<'id>)> for NodeIndex<'id> {
@@ -652,7 +732,7 @@ impl From<NodeIndex<'_>> for NodeIndexInternal {
 #[derive(Eq, PartialEq, Copy, Clone, Hash)]
 #[cfg_attr(feature = "core-fmt", derive(Debug))]
 //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13
-#[allow(repr_transparent_external_private_fields)]
+#[allow(repr_transparent_non_zst_fields)]
 pub struct StringIndex<'id>(usize, Id<'id>);
 
 impl<'id> From<(StringIndexInternal, Id<'id>)> for StringIndex<'id> {
@@ -798,19 +878,189 @@ impl fmt::Display for ScalarValue {
     }
 }
 
+/// Returns a permutation vector for the given list of [`Expression`] nodes
+/// and the number of reachable nodes.
+///
+/// Applying this permutation vector to the node list will topologically sort
+/// it. Nodes which cannot be reached from the root are marked with
+/// `usize::MAX` in the permutation vector; the root itself is always kept
+/// and must not have an incoming edge. Using `usize::MAX` as a marker is safe
+/// as, given our constraints, no index can be greater than `isize::MAX`.
+///
+/// # Errors
+///
+/// An error is returned if a cycle was found or if reserving space for the
+/// permutation vector failed.
+fn permutation_vector_of<A>(
+    alloc: A,
+    root: NodeIndexInternal,
+    vertices: &[NodeInternal],
+) -> Result<(Vec<usize, A>, usize), Error>
+where
+    A: Allocator + Clone,
+{
+    let mut indegree =
+        Vec::try_with_capacity_in(vertices.len(), alloc.clone())?;
+    indegree.resize(vertices.len(), 0);
+
+    //NOTE: used to queue nodes for reachability checks and Kahn's algorithm
+    let mut worklist = Vec::try_with_capacity_in(1, alloc.clone())?;
+
+    //SAFETY: we just reserved space for this element
+    unsafe {
+        worklist.push_within_capacity(root.0).unwrap_unchecked();
+    }
+
+    //NOTE: maps indices to their index in a topologically sorted node
+    //      vector. indices with the value of `usize::MAX` are not
+    //      reachable and should be removed by garbage collection. during the
+    //      reachability check, nodes which are reachable are mapped to
+    //      themselves as this is harmless and can be used to skip seen nodes
+    let mut permutation = Vec::try_with_capacity_in(vertices.len(), alloc)?;
+    permutation.resize(vertices.len(), usize::MAX);
+
+    let mut n_reachable = 0usize;
+    while let Some(current) = worklist.pop() {
+        //SAFETY: we have reserved enough space to track the permuted index
+        //        of all valid indices
+        let permuted_index =
+            unsafe { permutation.get_unchecked_mut(current) };
+        if *permuted_index == current {
+            continue;
+        }
+        *permuted_index = current;
+
+        //NOTE: this cannot wrap because allocation sizes are limited to
+        //      `isize::MAX`
+        n_reachable.wrapping_increment_mut();
+
+        //SAFETY: all indices in the vertex list are valid
+        #[allow(clippy::wildcard_enum_match_arm)]
+        match unsafe { vertices.get_unchecked(current) } {
+            NodeInternal::Operation(_, Some(NodeIndexInternal(i))) => {
+                //SAFETY: we have reserved enough space to track the indegree
+                //        of all valid indices
+                unsafe { indegree.get_unchecked_mut(*i) }
+                    .wrapping_increment_mut();
+
+                worklist.try_push(*i)?;
+            }
+            NodeInternal::Join(
+                NodeIndexInternal(i),
+                NodeIndexInternal(j),
+            )
+            | NodeInternal::BinaryOperation(
+                _,
+                NodeIndexInternal(i),
+                NodeIndexInternal(j),
+            ) => {
+                //SAFETY: ditto
+                unsafe { indegree.get_unchecked_mut(*i) }
+                    .wrapping_increment_mut();
+
+                worklist.try_push(*i)?;
+
+                //SAFETY: ditto
+                unsafe { indegree.get_unchecked_mut(*j) }
+                    .wrapping_increment_mut();
+
+                worklist.try_push(*j)?;
+            }
+            _ => (),
+        }
+    }
+
+    //NOTE: tracks the location of the cursor within the permuted vector and
+    //      serves as a count of the nodes which we have permuted
+    let mut permuted_position = 0;
+
+    //NOTE: double check there's no cycle that goes to root.
+    //SAFETY: we have reserved enough space to track the indegree of all
+    //        valid indices
+    if *unsafe { indegree.get_unchecked(root.0) } != 0 {
+        return Err(Error::Cycle);
+    }
+
+    worklist.try_push(root.0)?;
+    while let Some(current) = worklist.pop() {
+        //SAFETY: we constructed the permutation vector so that indexing into
+        //        it with vertex indices is always okay
+        *unsafe { permutation.get_unchecked_mut(current) } =
+            permuted_position;
+        permuted_position.wrapping_increment_mut();
+
+        //SAFETY: all indices in the vertex list are valid
+        #[allow(clippy::wildcard_enum_match_arm)]
+        match unsafe { vertices.get_unchecked(current) } {
+            NodeInternal::Operation(_, Some(NodeIndexInternal(i))) => {
+                //SAFETY: we have reserved enough space to track the indegree
+                //        of all valid indices
+                let indegree_of_i = unsafe { indegree.get_unchecked_mut(*i) };
+
+                indegree_of_i.wrapping_decrement_mut();
+                if *indegree_of_i == 0 {
+                    worklist.try_push(*i)?;
+                }
+            }
+            NodeInternal::Join(
+                NodeIndexInternal(i),
+                NodeIndexInternal(j),
+            )
+            | NodeInternal::BinaryOperation(
+                _,
+                NodeIndexInternal(i),
+                NodeIndexInternal(j),
+            ) => {
+                //SAFETY: ditto
+                let indegree_of_i = unsafe { indegree.get_unchecked_mut(*i) };
+
+                indegree_of_i.wrapping_decrement_mut();
+                if *indegree_of_i == 0 {
+                    worklist.try_push(*i)?;
+                }
+
+                //SAFETY: ditto
+                let indegree_of_j = unsafe { indegree.get_unchecked_mut(*j) };
+
+                indegree_of_j.wrapping_decrement_mut();
+                if *indegree_of_j == 0 {
+                    worklist.try_push(*j)?;
+                }
+            }
+            _ => (),
+        }
+    }
+
+    if n_reachable == permuted_position {
+        Ok((permutation, n_reachable))
+    } else {
+        Err(Error::Cycle)
+    }
+}
+
 /// Representation of an error that occurred within [`Expression`].
 #[non_exhaustive]
-#[derive(Eq, PartialEq)]
+#[derive(PartialEq, Eq)]
 #[cfg_attr(feature = "core-fmt", derive(Debug))]
 pub enum Error {
     /// Memory reservation error.
     TryReserveError(TryReserveError),
 
+    /// Unhandleable integer overflow error.
+    IntegerOverflowError,
+
     /// Invalid child node.
     ///
     /// This occurs when you attempt to pass an index that points to a
     /// [`Node::Join`] as a child of an operation.
     InvalidChild,
+
+    /// A cycle was found in the [`Expression`].
+    ///
+    /// This is impossible, but representable as an error state for testing
+    /// purposes.
+    #[doc(hidden)]
+    Cycle,
 }
 
 impl From<TryReserveError> for Error {
@@ -819,12 +1069,20 @@ impl From<TryReserveError> for Error {
     }
 }
 
+impl From<IntegerOverflowError> for Error {
+    fn from(_: IntegerOverflowError) -> Self {
+        Self::IntegerOverflowError
+    }
+}
+
 #[cfg(feature = "core-fmt")]
 impl fmt::Display for Error {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         let reason = match self {
             Error::TryReserveError(_) => "unable to allocate memory",
+            Error::IntegerOverflowError => "integer overflow",
             Error::InvalidChild => "invalid child node",
+            Error::Cycle => "expression contains a cycle",
         };
         f.write_str(reason)
     }
@@ -888,9 +1146,16 @@ mod test {
         );
 
         expr.set_root_node(product);
-        let expr = expr.close();
 
-        println!("{}", expr);
+        let old_length = expr.len();
+        let expr = expr.close().sort().expect("unable to sort an expression");
+        assert_eq!(old_length, expr.len());
+
+        make_guard!(h);
+        let mut expr = expr.open(h);
+        expr.take_root_node();
+        let expr = expr.close().sort().expect("unable to sort an expression");
+        assert!(expr.is_empty());
     }
 
     /// Reference state machine for [`Expression`]s.
@@ -1171,10 +1436,17 @@ mod test {
         type Reference = ExpressionReference;
 
         fn init_test(
-            _ref_state: &<Self::Reference as ReferenceStateMachine>::State,
+            ref_state: &<Self::Reference as ReferenceStateMachine>::State,
         ) -> Self::SystemUnderTest {
-            //TODO: randomly initialize state in the reference and replicate
-            //      it over here.
+            /*let mut expr = Expression {
+                inner: Vec::new(),
+                interns: Vec::new(),
+                root: ref_state.root.map(NodeIndexInternal),
+            };
+
+            for node in ref_state.nodes {
+                //TODO: finish replicating state over here
+            }*/
             ExpressionWrapper::Closed(Expression::new())
         }
 
@@ -1454,7 +1726,14 @@ mod test {
                 }
             }
 
-            //TODO: the graph must never contain a cycle.
+            //NOTE: ensure string conversion doesn't panic.
+            let _ = expr.to_string();
+
+            if let Some(root) = expr.root {
+                permutation_vector_of(Global, root, expr.inner.as_slice())
+                    .expect("expression must not contain a cycle");
+            }
+
             //TODO: validate the layout of the entire expression against the
             //      reference.
         }
@@ -1475,6 +1754,18 @@ mod test {
         })]
 
         #[test]
-        fn matches_state_machine(sequential 1..1_000 => ExpressionWrapper);
+        fn matches_state_machine(sequential 1..=250 => ExpressionWrapper);
+    }
+}
+
+#[cfg(all(kani, feature = "core-error"))]
+mod verification {
+    use super::*;
+
+    use generativity::make_guard;
+
+    #[kani::proof]
+    fn simple() {
+        //TODO: add more of a proof here
     }
 }