//! Common representation of mathematical expressions. use alloc::{ alloc::{Allocator, Global}, collections::TryReserveError, string::String, vec, vec::Vec, }; use core::{hint, iter::TrustedLen, mem::MaybeUninit}; use generativity::{Guard, Id}; use crate::{ numerics::rational::Rational, utilities::{ CheckedArithmeticExt, IntegerOverflowError, VecMemoryExt, WrappingArithmeticExt, }, }; #[cfg(feature = "core-fmt")] use core::fmt; #[cfg(feature = "core-error")] use core::error; //TODO: support marking a node for deletion when garbage collection is // performed. //TODO: consider btree/hashmap for interned strings to speed up the process // of locating the intern. //TODO: make topological sorting in-place. /// Representation of a mathematical expression. /// /// Components of the expression are maintained in a vector where they are /// linked to each other by their indices. Indices always point to another /// existing node. /// /// # Guarantees /// /// - The expression contains no cycles. /// - All references to other nodes are valid. /// - All references to interned strings are valid. /// - No strings are interned more than once. #[derive(Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub struct Expression where A: Allocator + Clone, { /// List of nodes linked to each other by their indices. pub(crate) inner: Vec, /// List of interned strings. pub(crate) interns: Vec, /// Root node of the expression. pub(crate) root: Option, /// Allocator used by the [`Expression`]. pub(crate) alloc: A, } impl Expression { /// Constructs an empty expression. #[must_use] pub fn new() -> Self { Self::new_in(Global) } } impl Expression where A: Allocator + Clone, { /// Constructs an empty expression using the provided [`Allocator`]. #[must_use] pub fn new_in(alloc: A) -> Self { Self { inner: Vec::new_in(alloc.clone()), interns: Vec::new_in(alloc.clone()), root: None, alloc, } } /// Clears the expression, removing all nodes and associated information. /// /// This does not free up unused space. To do that, use /// [`Expression::shrink_to_fit`] after calling this. pub fn clear(&mut self) { self.inner.clear(); self.interns.clear(); self.root = None; } /// Sorts and garbage collects the node list. /// /// This may not necessarily free the unused memory. To ensure that /// happens, call [`Expression::shrink_to_fit`] after calling this. /// /// # Errors /// /// An error is returned if reserving space for the permutation vector /// failed or if a cycle was found. /// /// The latter should not happen, but is enumerated in possible errors for /// testing purposes. Prefer crashing if this happens. pub fn sort(mut self) -> Result { if let Some(root) = self.root { let (permutation, n_reachable) = permutation_vector_of( self.alloc.clone(), root, self.inner.as_slice(), )?; let mut inner = Vec::with_capacity_in(n_reachable, self.alloc.clone()); let start = inner.as_mut_ptr() as *mut MaybeUninit; for (node, position) in self.inner.into_iter().zip(permutation) { if position != usize::MAX { //SAFETY: positions are guaranteed to be less than // `n_reachable` as `n_reachable` would have been // the next position, all indices are unique unsafe { start.add(position).write(MaybeUninit::new(node)); } } } //SAFETY: we sized the vector to contain exactly the number of // reachable nodes unsafe { inner.set_len(n_reachable) }; self.inner = inner; } else { self.clear(); } Ok(self) } /// Shrinks the capacity of the node list and intern list as much as /// possible. pub fn shrink_to_fit(&mut self) { self.inner.shrink_to_fit(); self.interns.shrink_to_fit(); } /// Opens up the expression for additive operations and indexing. /// /// While in this state, no destructive operations may be done, such as /// garbage collection and sorting of the node list. pub fn open<'id>(self, guard: Guard<'id>) -> OpenExpression<'id, A> { OpenExpression(self, guard.into()) } /// Inserts a [`NodeInternal`] into the node list and returns the index. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// node. fn insert_node( &mut self, node: NodeInternal, ) -> Result { let index = self.inner.len(); self.inner.try_push(node)?; Ok(NodeIndexInternal(index)) } /// Inserts a [`&str`] into the intern list if it doesn't already exist /// and returns the index. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// string. fn insert_string( &mut self, string: impl AsRef, ) -> Result { if let Some(i) = self.interns.iter().position(|s| s == string.as_ref()) { return Ok(StringIndexInternal(i)); } let index = self.interns.len(); self.interns.try_push(string.as_ref().into())?; Ok(StringIndexInternal(index)) } /// Returns the number of nodes within this expression. pub fn len(&self) -> usize { self.inner.len() } /// Indicates if the expression has no nodes. pub fn is_empty(&self) -> bool { self.inner.is_empty() } } impl Default for Expression { fn default() -> Self { Self::new() } } #[cfg(feature = "core-fmt")] impl fmt::Display for Expression where A: Allocator + Clone, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { /// State we are in when we are at a node. #[derive(PartialEq, Eq)] enum State { /// First time encountering a node. Enter, /// Second time encountering a node. Visit, /// Third time encountering a node. Close, } if let Some(root_index) = self.root { let mut stack = vec![(root_index, State::Enter)]; while let Some((node_index, state)) = stack.pop() { //SAFETY: all internal indices are guaranteed to be valid //NOTE: this lint is ignored because we have a state machine // test which exercises this wildcard arm #[allow(clippy::wildcard_enum_match_arm)] match *unsafe { self.inner.get_unchecked(node_index.0) } { NodeInternal::BinaryOperation(_, _, _) if state == State::Close => { f.write_str(")")?; } NodeInternal::BinaryOperation( StringIndexInternal(s), _, _, ) if state == State::Visit => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, " {} ", unsafe { self.interns.get_unchecked(s) })?; } NodeInternal::BinaryOperation(_, c1, c2) if state == State::Enter => { f.write_str("(")?; stack.push((node_index, State::Close)); stack.push((c2, State::Enter)); stack.push((node_index, State::Visit)); stack.push((c1, State::Enter)); } NodeInternal::Operation(StringIndexInternal(s), None) => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, "{}()", unsafe { self.interns.get_unchecked(s) })?; } NodeInternal::Operation(_, _) if state == State::Close => { f.write_str(")")?; } NodeInternal::Operation( StringIndexInternal(s), Some(c1), ) => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, "{}(", unsafe { self.interns.get_unchecked(s) })?; stack.push((node_index, State::Close)); stack.push((c1, State::Enter)); } NodeInternal::Join(_, _) if state == State::Visit => { f.write_str(", ")?; } NodeInternal::Join(c1, c2) => { stack.push((c2, State::Enter)); stack.push((node_index, State::Visit)); stack.push((c1, State::Enter)); } NodeInternal::Scalar(ref v) => { write!(f, "{}", v)?; } NodeInternal::Variable(StringIndexInternal(s)) => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, "{}", unsafe { self.interns.get_unchecked(s) })?; } _ => { #![allow(clippy::unreachable)] unreachable!() } } } } Ok(()) } } /// Representation of a mathematical expression. /// /// This is a variant of [`Expression`] in a state where it is able to be /// written to and indexed. #[derive(Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub struct OpenExpression<'id, A>(Expression, Id<'id>) where A: Allocator + Clone; impl<'id, A> OpenExpression<'id, A> where A: Allocator + Clone, { /// Adds a new scalar value to the expression. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// scalar value. pub fn insert_scalar( &mut self, value: impl Into, ) -> Result, Error> { Ok(( self.0.insert_node(NodeInternal::Scalar(value.into()))?, self.1, ) .into()) } /// Adds a new variable to the expression. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// variable. pub fn insert_variable( &mut self, name: impl AsRef, ) -> Result, Error> { let name = self.0.insert_string(name)?; Ok( (self.0.insert_node(NodeInternal::Variable(name))?, self.1) .into(), ) } /// Adds a new operation to the expression. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// operation or if one of the children of the node would have been a /// [`Node::Join`]. pub fn insert_operation( &mut self, operator: impl AsRef, arguments: Arguments, ) -> Result, Error> where Arguments: IntoIterator>, Arguments::IntoIter: TrustedLen, { let operator = self.0.insert_string(operator)?; let mut arguments = arguments.into_iter(); let n_arguments = arguments.size_hint().0; if n_arguments == 2 { //SAFETY: the iterator is required to implement `TrustedLen` let first = unsafe { arguments.next().unwrap_unchecked() }; //SAFETY: the iterator is required to implement `TrustedLen` let second = unsafe { arguments.next().unwrap_unchecked() }; if matches!( //SAFETY: lifetime branding ensures this index is valid unsafe { self.0.inner.get_unchecked(first.0) }, NodeInternal::Join(_, _) ) || matches!( //SAFETY: lifetime branding ensures this index is valid unsafe { self.0.inner.get_unchecked(second.0) }, NodeInternal::Join(_, _) ) { return Err(Error::InvalidChild); } Ok(( self.0.insert_node(NodeInternal::BinaryOperation( operator, first.into(), second.into(), ))?, self.1, ) .into()) } else { //NOTE: used to roll back the vector if an invalid child node is // found in the arguments let rollback_size = self.0.inner.len(); let number_of_joins = n_arguments.saturating_sub(1); self.0.inner.try_reserve(number_of_joins.errored_add(1)?)?; //NOTE: we define this beforehand to avoid arithmetic let operation_index = NodeIndexInternal(self.0.inner.len()); //SAFETY: the iterator is required to implement `TrustedLen` so // we will have reserved enough space for these in addition // to guaranteeing the noted elements exist // // additionally, because of our lifetime branding, all // indices referenced by lifetime-branded node indices are // valid unsafe { self.0 .inner .push_within_capacity(match n_arguments { 0 => NodeInternal::Operation(operator, None), 1 => { let argument = arguments.next().unwrap_unchecked(); if matches!( self.0.inner.get_unchecked(argument.0), NodeInternal::Join(_, _) ) { return Err(Error::InvalidChild); } NodeInternal::Operation( operator, Some(argument.into()), ) } //NOTE: we know that we will be adding at least one // more slot to hold an `NodeInternal::Join` _ => NodeInternal::Operation( operator, //NOTE: `Vec::len` is bounded by `isize::MAX` // for non-ZSTs Some(NodeIndexInternal( self.0.inner.len().wrapping_add(1), )), ), }) .unwrap_unchecked(); } //NOTE: we know by this point that the number of arguments is at // least 3 for the code that uses it let final_index = n_arguments.wrapping_sub(1); //NOTE: we are exclusively inserting `NodeInternal::Join`s now for (i, argument) in arguments.enumerate() { //NOTE: this check is necessary because it is possible for a // consumer to get a lifetime-branded index to a `Join` // node if matches!( //SAFETY: lifetime branding ensures this index is valid unsafe { self.0.inner.get_unchecked(argument.0) }, NodeInternal::Join(_, _) ) { //NOTE: this is necessary to avoid any invalid indices // being present in the vector when we perform // topological sort self.0.inner.truncate(rollback_size); return Err(Error::InvalidChild); } if i != final_index { //SAFETY: we know that we have reserved enough space for // each join by this point unsafe { self.0 .inner .push_within_capacity(NodeInternal::Join( argument.into(), //NOTE: `Vec::len` is bounded by // `isize::MAX` for non-ZSTs NodeIndexInternal( self.0.inner.len().wrapping_add(1), ), )) .unwrap_unchecked(); } } else { //NOTE: this is okay because we added the previous index let previous_index = self.0.inner.len().wrapping_sub(1); if let &mut NodeInternal::Join(_, ref mut second) = //SAFETY: we know that if we are at this point, // the current last element exists unsafe { self.0.inner.get_unchecked_mut(previous_index) } { *second = argument.into(); } else { //SAFETY: we know by this point that this last // element must be `NodeInternal::Join` unsafe { hint::unreachable_unchecked() } } } } Ok((operation_index, self.1).into()) } } /// Sets the root node of the expression. pub fn set_root_node(&mut self, node: NodeIndex<'id>) { self.0.root = Some(node.into()); } /// Closes the expression to additive operations and being indexed. pub fn close(self) -> Expression { self.0 } /// Returns the number of nodes within this expression. pub fn len(&self) -> usize { self.0.len() } /// Indicates if the expression has no nodes. pub fn is_empty(&self) -> bool { self.0.is_empty() } /// Gets the root node of the expression, if it exists. pub fn root_node(&self) -> Option> { Some((self.0.root?, self.1).into()) } /// Takes the expression's root node. pub fn take_root_node(&mut self) -> Option> { Some((self.0.root.take()?, self.1).into()) } /// Retrieves the indicated node from the expression. pub fn node(&self, index: NodeIndex<'id>) -> Node<'id> { //SAFETY: by the lifetime, this is guaranteed to be a valid index ( unsafe { self.0.inner.get_unchecked(index.0) }.clone(), self.1, ) .into() } /// Retrieves the indicated string from the expression. pub fn string(&self, index: StringIndex<'id>) -> &str { //SAFETY: by the lifetime, this is guaranteed to be a valid index unsafe { self.0.interns.get_unchecked(index.0) } } /// Returns an iterator over the children of the node. pub fn children_of<'a>( &'a self, index: NodeIndex<'id>, ) -> Children<'a, 'id> { Children::new(&self.0.inner, index) } } #[cfg(feature = "core-fmt")] impl fmt::Display for OpenExpression<'_, A> where A: Allocator + Clone, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } /// Iterator over the children of a [`Node`]. pub struct Children<'a, 'id> { /// The node we are currently at. /// /// If the node given to the constructor was a [`Node::Scalar`], /// [`Node::Variable`], or a [`Node::Operation`] with `None` as the /// second parameter, this must be `None`, and hence this may never be /// referencing a node of those kinds. current: Option, /// Whether or not we are at the "second" part of the node. /// /// This is only relevant for [`Node::BinaryOperation`] as it is not /// clear which of the children are being referenced. second: bool, /// Reference to the internal storage of the [`Expression`]. inner: &'a [NodeInternal], /// Lifetime tying it to the open state of the [`Expression`]. id: Id<'id>, } impl<'a, 'id> Children<'a, 'id> { /// Creates a new iterator over the children of a [`Node`]. fn new(inner: &'a [NodeInternal], index: NodeIndex<'id>) -> Self { let id = index.1; if matches!( //SAFETY: by the lifetime, this is guaranteed to be a valid // index unsafe { inner.get_unchecked(index.0) }, NodeInternal::Scalar(_) | NodeInternal::Variable(_) | NodeInternal::Operation(_, None) ) { Self { current: None, second: false, inner, id, } } else { Self { current: Some(index.into()), second: false, inner, id, } } } } impl<'a, 'id> Iterator for Children<'a, 'id> { type Item = NodeIndex<'id>; fn next(&mut self) -> Option { if let Some(current) = self.current { //SAFETY: by the lifetime, this is guaranteed to be a valid index let current = unsafe { self.inner.get_unchecked(current.0) }; Some( ( //NOTE: this lint is ignored because we have a state // machine test which exercises this wildcard arm #[allow(clippy::wildcard_enum_match_arm)] match *current { NodeInternal::Join(_, index) | NodeInternal::Operation(_, Some(index)) => { //SAFETY: all internal indices are guaranteed to // be valid let next = unsafe { self.inner.get_unchecked(index.0) }; if let NodeInternal::Join(child_index, _) = *next { self.current = Some(index); child_index } else { self.current = None; index } } NodeInternal::BinaryOperation(_, index, _) if !self.second => { self.second = true; index } NodeInternal::BinaryOperation(_, _, index) => { self.current = None; index } _ => { #![allow(clippy::unreachable)] unreachable!() } }, self.id, ) .into(), ) } else { None } } } /// Index into the node vector of [`Expression`]. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 #[allow(repr_transparent_non_zst_fields)] pub struct NodeIndex<'id>(usize, Id<'id>); impl<'id> From<(NodeIndexInternal, Id<'id>)> for NodeIndex<'id> { fn from((index, id): (NodeIndexInternal, Id<'id>)) -> Self { Self(index.0, id) } } /// Index into the node vector of [`Expression`]. /// /// This is only for internal use to avoid the presence of explicit lifetimes /// on [`Expression`] members. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub(crate) struct NodeIndexInternal(usize); impl From> for NodeIndexInternal { fn from(NodeIndex(index, _): NodeIndex<'_>) -> Self { Self(index) } } /// Index into the string intern vector of [`Expression`]. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 #[allow(repr_transparent_non_zst_fields)] pub struct StringIndex<'id>(usize, Id<'id>); impl<'id> From<(StringIndexInternal, Id<'id>)> for StringIndex<'id> { fn from((index, id): (StringIndexInternal, Id<'id>)) -> Self { Self(index.0, id) } } /// Index into the string intern vector of [`Expression`]. /// /// This is only for internal use to avoid the presence of explicit lifetimes /// on [`Expression`] members. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub(crate) struct StringIndexInternal(usize); impl From> for StringIndexInternal { fn from(StringIndex(index, _): StringIndex<'_>) -> Self { Self(index) } } /// Internal representation of an expression node. #[non_exhaustive] #[derive(PartialEq, Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub(crate) enum NodeInternal { /// Operation over other nodes. /// /// The first component indicates the kind of operation and the second is /// the first argument. Several arguments may be represented through the /// use of `Join`. Operation(StringIndexInternal, Option), /// Binary operation. /// /// This is special cased due to how common they are. `Join`s are not /// permitted to be referenced by the indices here. BinaryOperation( StringIndexInternal, NodeIndexInternal, NodeIndexInternal, ), /// Joins one `NodeInternal` with another in a sequence, with the first /// coming before the second. /// /// The first `NodeIndexInternal` must not be a `Join`. Join(NodeIndexInternal, NodeIndexInternal), /// Scalar value. Scalar(ScalarValue), /// Variable. Variable(StringIndexInternal), } /// Node in an expression. #[non_exhaustive] #[derive(PartialEq, Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub enum Node<'id> { /// Operation over other nodes. /// /// The first component indicates the kind of operation and the second is /// the first argument. Several arguments may be represented through the /// use of `Join`. Operation(StringIndex<'id>, Option>), /// Binary operation. /// /// This is special cased due to how common they are. `Join`s are not /// permitted to be referenced by the indices here. BinaryOperation(StringIndex<'id>, NodeIndex<'id>, NodeIndex<'id>), /// Joins one `NodeInternal` with another in a sequence, with the first /// coming before the second. /// /// The first `NodeIndexInternal` must not be a `Join`. Join(NodeIndex<'id>, NodeIndex<'id>), /// Scalar value. Scalar(ScalarValue), /// Variable. Variable(StringIndex<'id>), } impl<'id> From<(NodeInternal, Id<'id>)> for Node<'id> { fn from((node, id): (NodeInternal, Id<'id>)) -> Self { match node { NodeInternal::Operation(s, n) => { Node::Operation((s, id).into(), n.map(|n| (n, id).into())) } NodeInternal::BinaryOperation(s, n1, n2) => { Node::BinaryOperation( (s, id).into(), (n1, id).into(), (n2, id).into(), ) } NodeInternal::Join(n1, n2) => { Node::Join((n1, id).into(), (n2, id).into()) } NodeInternal::Scalar(s) => Node::Scalar(s), NodeInternal::Variable(s) => Node::Variable((s, id).into()), } } } /// Scalar value. #[non_exhaustive] #[derive(PartialEq, Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub enum ScalarValue { /// Unsigned integer. UnsignedInteger(u64), /// Signed integer. Integer(i64), /// Floating point value. Float(f64), /// Rational number with unsigned integer components. UnsignedIntegerRational(Rational), /// Rational number with signed integer components. SignedIntegerRational(Rational), } #[cfg(feature = "core-fmt")] impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::UnsignedInteger(v) => write!(f, "{}", v), Self::Integer(v) => write!(f, "{}", v), Self::Float(v) => write!(f, "{}", v), Self::UnsignedIntegerRational(v) => write!(f, "{}", v), Self::SignedIntegerRational(v) => write!(f, "{}", v), } } } /// Returns a permutation vector for the given list of [`Expression`] nodes /// and the number of reachable nodes. /// /// Applying this permutation vector to the node list will topologically sort /// it. Nodes without any incoming edges will be marked with `usize::MAX` in /// the permutation vector. The root node is ignored as it is not expected to /// have an incoming edge. This is acceptable as given our constraints, it is /// not possible for an index to be greater than `isize::MAX`. /// /// # Errors /// /// An error is returned if a cycle was found or if reserving space for the /// permutation vector failed. fn permutation_vector_of( alloc: A, root: NodeIndexInternal, vertices: &[NodeInternal], ) -> Result<(Vec, usize), Error> where A: Allocator + Clone, { let mut indegree = Vec::try_with_capacity_in(vertices.len(), alloc.clone())?; indegree.resize(vertices.len(), 0); //NOTE: used to queue nodes for reachability checks and Kahn's algorithm let mut worklist = Vec::try_with_capacity_in(1, alloc.clone())?; //SAFETY: we just reserved space for this element unsafe { worklist.push_within_capacity(root.0).unwrap_unchecked(); } //NOTE: maps indices to their index in a topologically sorted node // vector. indices with the value of `usize::MAX` are not // reachable and should be removed by garbage collection. during the // reachability check, nodes which are reachable are mapped to // themselves as this is harmless and can be used to skip seen nodes let mut permutation = Vec::try_with_capacity_in(vertices.len(), alloc)?; permutation.resize(vertices.len(), usize::MAX); let mut n_reachable = 0usize; while let Some(current) = worklist.pop() { //SAFETY: we have reserved enough space to track the permuted index // of all valid indices let permuted_index = unsafe { permutation.get_unchecked_mut(current) }; if *permuted_index == current { continue; } *permuted_index = current; //NOTE: this is okay because allocation sizes are limited to // `isize::MAX`; n_reachable.wrapping_increment_mut(); //SAFETY: all indices in the vertex list are valid #[allow(clippy::wildcard_enum_match_arm)] match unsafe { vertices.get_unchecked(current) } { NodeInternal::Operation(_, Some(NodeIndexInternal(i))) => { //SAFETY: we have reserved enough space to track the indegree // of all valid indices unsafe { indegree.get_unchecked_mut(*i) } .wrapping_increment_mut(); worklist.try_push(*i)?; } NodeInternal::Join( NodeIndexInternal(i), NodeIndexInternal(j), ) | NodeInternal::BinaryOperation( _, NodeIndexInternal(i), NodeIndexInternal(j), ) => { //SAFETY: ditto unsafe { indegree.get_unchecked_mut(*i) } .wrapping_increment_mut(); worklist.try_push(*i)?; //SAFETY: ditto unsafe { indegree.get_unchecked_mut(*j) } .wrapping_increment_mut(); worklist.try_push(*j)?; } _ => (), } } //NOTE: tracks the location of the cursor within the permuted vector and // serves as a count of the nodes which we have permuted let mut permuted_position = 0; //NOTE: double check there's no cycle that goes to root. //SAFETY: we have reserved enough space to track the indegree of all // valid indices if *unsafe { indegree.get_unchecked(root.0) } != 0 { return Err(Error::Cycle); } worklist.try_push(root.0)?; while let Some(current) = worklist.pop() { //SAFETY: we constructed the permutation vector so that indexing into // it with vertex indices is always okay *unsafe { permutation.get_unchecked_mut(current) } = permuted_position; permuted_position.wrapping_increment_mut(); //SAFETY: all indices in the vertex list are valid #[allow(clippy::wildcard_enum_match_arm)] match unsafe { vertices.get_unchecked(current) } { NodeInternal::Operation(_, Some(NodeIndexInternal(i))) => { //SAFETY: we have reserved enough space to track the indegree // of all valid indices let indegree_of_i = unsafe { indegree.get_unchecked_mut(*i) }; indegree_of_i.wrapping_decrement_mut(); if *indegree_of_i == 0 { worklist.try_push(*i)?; } } NodeInternal::Join( NodeIndexInternal(i), NodeIndexInternal(j), ) | NodeInternal::BinaryOperation( _, NodeIndexInternal(i), NodeIndexInternal(j), ) => { //SAFETY: ditto let indegree_of_i = unsafe { indegree.get_unchecked_mut(*i) }; indegree_of_i.wrapping_decrement_mut(); if *indegree_of_i == 0 { worklist.try_push(*i)?; } //SAFETY: ditto let indegree_of_j = unsafe { indegree.get_unchecked_mut(*j) }; indegree_of_j.wrapping_decrement_mut(); if *indegree_of_j == 0 { worklist.try_push(*j)?; } } _ => (), } } if n_reachable == permuted_position { Ok((permutation, n_reachable)) } else { Err(Error::Cycle) } } /// Representation of an error that occurred within [`Expression`]. #[non_exhaustive] #[derive(Clone, PartialEq, Eq)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub enum Error { /// Memory reservation error. TryReserveError(TryReserveError), /// Unhandleable integer overflow error. IntegerOverflowError, /// Invalid child node. /// /// This occurs when you attempt to pass an index that points to a /// [`Node::Join`] as a child of an operation. InvalidChild, /// A cycle was found in the [`Expression`]. /// /// This is impossible, but representable as an error state for testing /// purposes. #[doc(hidden)] Cycle, } impl From for Error { fn from(e: TryReserveError) -> Self { Self::TryReserveError(e) } } impl From for Error { fn from(_: IntegerOverflowError) -> Self { Self::IntegerOverflowError } } #[cfg(feature = "core-fmt")] impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let reason = match self { Error::TryReserveError(_) => "unable to allocate memory", Error::IntegerOverflowError => "integer overflow", Error::InvalidChild => "invalid child node", Error::Cycle => "expression contains a cycle", }; f.write_str(reason) } } #[cfg(feature = "core-error")] impl error::Error for Error {} #[cfg(all(test, feature = "core-error"))] mod test { #![allow(clippy::arithmetic_side_effects)] #![allow(clippy::indexing_slicing)] #![allow(clippy::unreachable)] #![allow(clippy::wildcard_enum_match_arm)] #![allow(clippy::expect_used)] use super::*; use generativity::make_guard; use proptest::prelude::*; use proptest_state_machine::{ ReferenceStateMachine, StateMachineTest, prop_state_machine, }; use std::{assert_matches::assert_matches, collections::HashMap}; #[test] fn simple() { make_guard!(g); let mut expr = Expression::new().open(g); let a = expr .insert_variable("apples") .expect("unable to insert a variable"); assert_eq!(expr.children_of(a).collect::>().as_slice(), &[]); let b = expr .insert_variable("bananas") .expect("unable to insert a variable"); let c = expr .insert_variable("oranges") .expect("unable to insert a variable"); let sin = expr .insert_operation("sin", [a]) .expect("unable to insert an operation"); assert_eq!( expr.children_of(sin).collect::>().as_slice(), &[a] ); let sum = expr .insert_operation("sum", [a, b, c]) .expect("unable to insert an operation"); assert_eq!( expr.children_of(sum).collect::>().as_slice(), &[a, b, c] ); let product = expr .insert_operation("*", [sin, sum]) .expect("unable to insert an operation"); assert_eq!( expr.children_of(product).collect::>().as_slice(), &[sin, sum] ); expr.set_root_node(product); let old_length = expr.len(); let expr = expr.close().sort().expect("unable to sort an expression"); assert_eq!(old_length, expr.len()); make_guard!(h); let mut expr = expr.open(h); expr.take_root_node(); let expr = expr.close().sort().expect("unable to sort an expression"); assert!(expr.is_empty()); } /// Reference state machine for [`Expression`]s. /// /// Models [`Expression`]s as a root node index, a vector of nodes, a /// mapping of parent node indices to a vector of child node indices, /// a vector of valid indices, and whether or not it is currently open. struct ExpressionReference; /// State used by the reference state machine. #[derive(Clone, Debug, Default)] struct ExpressionReferenceState { /// Index of the root node in the node vector. pub root: Option, /// Vector of nodes in the expression. pub nodes: Vec, /// Mapping of parent node indices to child node indices. pub children: HashMap>, /// The current list of valid indices. pub valid_indices: Vec, /// Whether or not the expression is open. /// /// If it is open, then additive operations and indexing is allowed. /// Otherwise, operations that may invalidate indices such as node /// removal, sorting, etc are allowed. pub is_open: bool, } /// Simplified [`NodeInternal`] representation for the [`HashMap`] backed /// graph. #[derive(Clone, Debug)] enum NodeRep { /// Operation over other nodes. Operation(String), /// Scalar value. Scalar(ScalarValue), /// Variable. Variable(String), } /// Possible transitions for the state machine. #[derive(Clone, Debug)] enum ExpressionTransition { /// Open the expression. Open, /// Add a new operation to the expression. NewOperation(String, Vec), /// Add a new scalar to the expression. NewScalar(ScalarValue), /// Add a new variable to the expression. NewVariable(String), /// Set the root node of the expression. SetRoot(usize), /// Remove the root of the expression. TakeRoot, /// Acquire a valid index to the root of expression. AcquireRoot, /// Acquire valid indices to the children of the given node. AcquireChildren(usize), /// Close the expression, invalidating all currently held indices. Close, /// Remove all nodes from the expression. Clear, } impl ExpressionTransition { /// Indicates if this transition requires that the expression be open. fn needs_open(&self) -> bool { matches!( self, ExpressionTransition::NewOperation(_, _) | ExpressionTransition::NewScalar(_) | ExpressionTransition::NewVariable(_) | ExpressionTransition::SetRoot(_) | ExpressionTransition::TakeRoot | ExpressionTransition::AcquireRoot | ExpressionTransition::AcquireChildren(_) | ExpressionTransition::Close ) } } impl ReferenceStateMachine for ExpressionReference { type State = ExpressionReferenceState; type Transition = ExpressionTransition; fn init_state() -> BoxedStrategy { //NOTE: we could randomly initialize state. i don't think this is // necessary to test this sufficiently, but if we ever do, we // also need to tweak the constructor of the [`Expression`] // to replicate one from the reference state Just(Self::State::default()).boxed() } fn transitions( state: &Self::State, ) -> BoxedStrategy { if state.is_open && state.valid_indices.is_empty() { prop_oneof![ 1 => Just(ExpressionTransition::Close), 1 => Just(ExpressionTransition::TakeRoot), 5 => Just(ExpressionTransition::AcquireRoot), //NOTE: what a scalar value actually is has zero // bearing on anything so we can just yield a static // value here without any problems. in the future we // may have something like a common subexpression // elimination feature, which would make varying it // valuable, but for now, we have no reason to do // anything other than adding a static scalar. this // is different for variables because string // interning exists, so varying those does actually // have an effect 7 => Just(ExpressionTransition::NewScalar( ScalarValue::UnsignedInteger(1) )), // 702 possible identifiers is probably enough 7 => "[a-z]{1,2}" .prop_map(ExpressionTransition::NewVariable), ] .boxed() } else if state.is_open { let n_valid = state.valid_indices.len(); prop_oneof![ 1 => Just(ExpressionTransition::Close), 1 => Just(ExpressionTransition::TakeRoot), 5 => Just(ExpressionTransition::AcquireRoot), //NOTE: what a scalar value actually is has zero // bearing on anything so we can just yield a static // value here without any problems. in the future we // may have something like a common subexpression // elimination feature, which would make varying it // valuable, but for now, we have no reason to do // anything other than adding a static scalar. this // is different for variables because string // interning exists, so varying those does actually // have an effect 7 => Just(ExpressionTransition::NewScalar( ScalarValue::UnsignedInteger(1) )), // 702 possible identifiers is probably enough 7 => "[a-z]{1,2}" .prop_map(ExpressionTransition::NewVariable), 9 => prop::sample::subsequence( state.valid_indices.clone(), 0..=n_valid ) //NOTE: 702 possible identifiers is // probably enough .prop_flat_map(|v| ("[A-Z]{1,2}", Just(v))) .prop_map(|(s, v)| ExpressionTransition::NewOperation(s, v) ), 1 => prop::sample::select(state.valid_indices.clone()) .prop_map(ExpressionTransition::SetRoot), 5 => prop::sample::select(state.valid_indices.clone()) .prop_map(ExpressionTransition::AcquireChildren), ] .boxed() } else { prop_oneof![ 1 => Just(ExpressionTransition::Clear), 9 => Just(ExpressionTransition::Open), ] .boxed() } } fn preconditions( state: &Self::State, transition: &Self::Transition, ) -> bool { if state.is_open ^ transition.needs_open() { return false; } match transition { ExpressionTransition::NewOperation(_, children) => { for c in children { if *c >= state.nodes.len() || !state.valid_indices.contains(c) { return false; } } } ExpressionTransition::SetRoot(node) | ExpressionTransition::AcquireChildren(node) => { if *node >= state.nodes.len() || !state.valid_indices.contains(node) { return false; } } _ => (), } true } fn apply( mut state: Self::State, transition: &Self::Transition, ) -> Self::State { match transition { ExpressionTransition::Open => state.is_open = true, ExpressionTransition::Clear => { state.nodes.clear(); state.children.clear(); state.root = None; } ExpressionTransition::Close => { state.valid_indices.clear(); state.is_open = false; } ExpressionTransition::SetRoot(index) => { state.root = Some(*index); } ExpressionTransition::TakeRoot => { state.root = None; } ExpressionTransition::NewVariable(s) => { state.nodes.push(NodeRep::Variable(s.clone())); let i = state.nodes.len() - 1; state.valid_indices.push(i); state.children.insert(i, Vec::new()); } ExpressionTransition::NewScalar(v) => { state.nodes.push(NodeRep::Scalar(v.clone())); let i = state.nodes.len() - 1; state.valid_indices.push(i); state.children.insert(i, Vec::new()); } ExpressionTransition::AcquireRoot => { if let Some(root) = state.root { state.valid_indices.push(root); } } ExpressionTransition::AcquireChildren(i) => { state.valid_indices.extend(&state.children[i]); } ExpressionTransition::NewOperation(s, c) => { state.nodes.push(NodeRep::Operation(s.clone())); let i = state.nodes.len() - 1; state.children.insert(i, c.clone()); state.valid_indices.push(i); } } state } } /// Wrapper around an [`Expression`] indicating if it is open or not. enum ExpressionWrapper<'id, A = Global> where A: Allocator + Clone, { /// Closed expression.. Closed(Expression), /// Open expression and a mapping from the state machine's indices to /// the [`Expression`]'s. Open(HashMap>, OpenExpression<'id, A>), } impl<'id> StateMachineTest for ExpressionWrapper<'id> { type SystemUnderTest = Self; type Reference = ExpressionReference; fn init_test( _ref_state: &::State, ) -> Self::SystemUnderTest { ExpressionWrapper::Closed(Expression::new()) } fn apply( mut state: Self::SystemUnderTest, ref_state: &::State, transition: ExpressionTransition, ) -> Self::SystemUnderTest { match transition { ExpressionTransition::Open => { if let ExpressionWrapper::Closed(closed_state) = state { //SAFETY: unfortunately, we have to do this, // otherwise this test cannot work out due to // lifetime issues let guard = unsafe { Guard::new(Id::new()) }; state = ExpressionWrapper::Open( HashMap::new(), closed_state.open(guard), ); } else { unreachable!(); } } ExpressionTransition::Clear => { if let ExpressionWrapper::Closed(closed_state) = &mut state { closed_state.clear(); //NOTE: post-conditions assert!( closed_state.inner.is_empty(), "inner node vector should be empty after clearing" ); assert!( closed_state.interns.is_empty(), "intern vector should be empty after clearing" ); assert!( closed_state.root.is_none(), "root node should not exist after clearing" ); } else { unreachable!(); } } ExpressionTransition::Close => { if let ExpressionWrapper::Open(_, open_state) = state { state = ExpressionWrapper::Closed(open_state.close()); } else { unreachable!(); } } ExpressionTransition::SetRoot(index) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { open_state.set_root_node(mapping[&index]); //NOTE: post-conditions assert_eq!( open_state.0.root, Some(NodeIndexInternal(mapping[&index].0)), "root should be set to the given index after setting it" ); } else { unreachable!(); } } ExpressionTransition::TakeRoot => { if let ExpressionWrapper::Open(_, open_state) = &mut state { let root_before = open_state.0.root; let taken_root = open_state .take_root_node() .map(|r| NodeIndexInternal(r.0)); //NOTE: post-conditions assert!( open_state.0.root.is_none(), "root should be None after taking it" ); assert_eq!( root_before, taken_root, "root before taking and returned root value from taking should be equal" ); } else { unreachable!(); } } ExpressionTransition::NewVariable(s) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let var = open_state .insert_variable(&s) .expect("inserting a variable should not fail"); mapping.insert(ref_state.nodes.len() - 1, var); let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted variable"); assert_matches!( open_state.0.inner[var.0], NodeInternal::Variable(position), "the variable node should point to the position of the equivalent string in the intern list" ); } else { unreachable!(); } } ExpressionTransition::NewScalar(v) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let scalar = open_state .insert_scalar(v) .expect("inserting a scalar should not fail"); mapping.insert(ref_state.nodes.len() - 1, scalar); assert_matches!( &open_state.0.inner[scalar.0], NodeInternal::Scalar(v), "the scalar node should contain the same value as was inserted" ); } else { unreachable!(); } } ExpressionTransition::NewOperation(s, c) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let operation = open_state .insert_operation( &s, c.iter().copied().map(|i| mapping[&i]), ) .expect("inserting an operation should not fail"); mapping.insert(ref_state.nodes.len() - 1, operation); let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted operation"); if c.len() == 2 { assert_matches!( open_state.0.inner[operation.0], NodeInternal::BinaryOperation(position, _, _), "the operation node should point to the position of the equivalent string in the intern list", ); } else { assert_matches!( open_state.0.inner[operation.0], NodeInternal::Operation(position, _), "the operation node should point to the position of the equivalent string in the intern list", ); } //TODO: validate join nodes, child node equality } else { unreachable!(); } } ExpressionTransition::AcquireRoot => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let root = open_state.root_node(); assert_eq!( ref_state.root.is_some(), root.is_some(), "both the reference state and system under test must match" ); if let Some(root) = root && let Some(ref_root) = ref_state.root { mapping.insert(ref_root, root); } } else { unreachable!(); } } ExpressionTransition::AcquireChildren(i) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { mapping.extend( ref_state.children[&i] .iter() .copied() .zip(open_state.children_of(mapping[&i])), ); } else { unreachable!(); } } } state } fn check_invariants( state: &Self::SystemUnderTest, ref_state: &::State, ) { let expr = match state { ExpressionWrapper::Open(_, expr) => &expr.0, ExpressionWrapper::Closed(expr) => expr, }; let mut deduplicated_interns = expr.interns.clone(); deduplicated_interns.sort(); deduplicated_interns.dedup(); assert!( deduplicated_interns.len() == expr.interns.len(), "there must be no duplicated string interns" ); for node in &expr.inner { match node { NodeInternal::Operation(i, Some(n)) => { assert!( i.0 < expr.interns.len(), "intern index must be valid" ); assert!( n.0 < expr.inner.len(), "node index must be valid" ); } NodeInternal::BinaryOperation(i, n1, n2) => { assert!( i.0 < expr.interns.len(), "intern index must be valid" ); assert!( n1.0 < expr.inner.len(), "node index must be valid" ); assert!( n2.0 < expr.inner.len(), "node index must be valid" ); assert!( !matches!( expr.inner[n1.0], NodeInternal::Join(_, _) ), "binary operations must not reference joins" ); assert!( !matches!( expr.inner[n2.0], NodeInternal::Join(_, _) ), "binary operations must not reference joins" ); } NodeInternal::Join(n1, n2) => { assert!( n1.0 < expr.inner.len(), "node index must be valid" ); assert!( !matches!( expr.inner[n1.0], NodeInternal::Join(_, _) ), "operations must not have joins as children" ); assert!( n2.0 < expr.inner.len(), "node index must be valid" ); } NodeInternal::Variable(i) => { assert!( i.0 < expr.interns.len(), "intern index must be valid" ); } _ => (), } } //NOTE: ensure string conversion doesn't panic. let _ = expr.to_string(); if let Some(root) = expr.root { permutation_vector_of(Global, root, expr.inner.as_slice()) .expect("expression must not contain a cycle"); } //TODO: validate the layout of the entire expression against the // reference. } } prop_state_machine! { #![proptest_config(ProptestConfig { // allow more rejects to make shrinking results better max_global_rejects: u32::MAX, // allow more shrinking iterations to improve the results max_shrink_iters: u32::MAX, // disable failure persistence so miri works #[cfg(miri)] failure_persistence: None, ..ProptestConfig::default() })] #[test] fn matches_state_machine(sequential 1..=250 => ExpressionWrapper); } } #[cfg(all(kani, feature = "core-error"))] mod verification { use super::*; use generativity::make_guard; #[kani::proof] fn simple() { //TODO: add more of a proof here } }