//! Common representation of mathematical expressions. use alloc::{ alloc::{Allocator, Global}, collections::TryReserveError, string::String, vec, vec::Vec, }; use core::{hint, iter::TrustedLen}; use generativity::{Guard, Id}; use crate::numerics::rational::Rational; #[cfg(feature = "core-fmt")] use core::fmt; #[cfg(feature = "core-error")] use core::error; //TODO: consider topological sorting of the expression's internals as this // may be better for performance. //TODO: consider garbage collection of unreachable nodes (and their // children). //TODO: if we add garbage collection, support marking a node for deletion // when garbage collection is performed. //TODO: consider renaming `OpenExpression` to something else. //TODO: add more thorough documentation, along the lines of the standard // library's `Vec`, given how this is a very fundamental type. /// Representation of a mathematical expression. /// /// Components of the expression are maintained in a vector where they are /// linked to each other by their indices. Indices always point to another /// existing node. /// /// # Guarantees /// /// - The expression contains no cycles. /// - All references to other nodes are valid. /// - All references to interned strings are valid. /// - No strings are interned more than once. #[derive(Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub struct Expression where A: Allocator + Clone, { /// List of nodes linked to each other by their indices. pub(crate) inner: Vec, /// List of interned strings. pub(crate) interns: Vec, /// Root node of the expression. pub(crate) root: Option, } impl Expression { /// Constructs an empty expression. pub fn new() -> Self { Self::new_in(Global) } } impl Expression where A: Allocator + Clone, { /// Constructs an empty expression using the provided [`Allocator`]. pub fn new_in(alloc: A) -> Self { Self { inner: Vec::new_in(alloc.clone()), interns: Vec::new_in(alloc), root: None, } } /// Clears the expression, removing all nodes and associated information. pub fn clear(&mut self) { self.inner.clear(); self.interns.clear(); self.root = None; } /// Opens up the expression for additive operations and indexing. /// /// While in this state, no destructive operations may be done, such as /// garbage collection and sorting of the node list. pub fn open<'id>(self, guard: Guard<'id>) -> OpenExpression<'id, A> { OpenExpression(self, guard.into()) } /// Inserts a [`NodeInternal`] into the node list and returns the index. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// node. fn insert_node( &mut self, node: NodeInternal, ) -> Result { self.inner.try_reserve(1)?; //SAFETY: we just reserved space for this element unsafe { self.inner.push_within_capacity(node).unwrap_unchecked() } Ok(NodeIndexInternal(self.inner.len() - 1)) } /// Inserts a [`&str`] into the intern list if it doesn't already exist /// and returns the index. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// string. fn insert_string( &mut self, string: impl AsRef, ) -> Result { if let Some(i) = self.interns.iter().position(|s| s == string.as_ref()) { return Ok(StringIndexInternal(i)); } self.interns.try_reserve(1)?; //SAFETY: we just reserved space for this element unsafe { self.interns .push_within_capacity(string.as_ref().into()) .unwrap_unchecked() } Ok(StringIndexInternal(self.interns.len() - 1)) } /// Returns the number of nodes within this expression. pub fn len(&self) -> usize { self.inner.len() } /// Indicates if the expression has no nodes. pub fn is_empty(&self) -> bool { self.inner.is_empty() } } impl Default for Expression { fn default() -> Self { Self::new() } } #[cfg(feature = "core-fmt")] impl fmt::Display for Expression where A: Allocator + Clone, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { /// State we are in when we are at a node. #[derive(PartialEq, Eq)] enum State { /// First time encountering a node. Enter, /// Second time encountering a node. Visit, /// Third time encountering a node. Close, } if let Some(root_index) = self.root { let mut stack = vec![(root_index, State::Enter)]; while let Some((node_index, state)) = stack.pop() { //SAFETY: all internal indices are guaranteed to be valid match *unsafe { self.inner.get_unchecked(node_index.0) } { NodeInternal::BinaryOperation(_, _, _) if state == State::Close => { f.write_str(")")?; } NodeInternal::BinaryOperation( StringIndexInternal(s), _, _, ) if state == State::Visit => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, " {} ", unsafe { self.interns.get_unchecked(s) })?; } NodeInternal::BinaryOperation(_, c1, c2) if state == State::Enter => { f.write_str("(")?; stack.push((node_index, State::Close)); stack.push((c2, State::Enter)); stack.push((node_index, State::Visit)); stack.push((c1, State::Enter)); } NodeInternal::Operation(StringIndexInternal(s), None) => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, "{}()", unsafe { self.interns.get_unchecked(s) })?; } NodeInternal::Operation(_, _) if state == State::Close => { f.write_str(")")?; } NodeInternal::Operation( StringIndexInternal(s), Some(c1), ) => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, "{}(", unsafe { self.interns.get_unchecked(s) })?; stack.push((node_index, State::Close)); stack.push((c1, State::Enter)); } NodeInternal::Join(_, _) if state == State::Visit => { f.write_str(", ")?; } NodeInternal::Join(c1, c2) => { stack.push((c2, State::Enter)); stack.push((node_index, State::Visit)); stack.push((c1, State::Enter)); } NodeInternal::Scalar(ref v) => { write!(f, "{}", v)?; } NodeInternal::Variable(StringIndexInternal(s)) => { //SAFETY: all internal indices are guaranteed to be // valid write!(f, "{}", unsafe { self.interns.get_unchecked(s) })?; } _ => unreachable!(), } } } Ok(()) } } /// Representation of a mathematical expression. /// /// This is a variant of [`Expression`] in a state where it is able to be /// written to and indexed. #[derive(Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub struct OpenExpression<'id, A>(Expression, Id<'id>) where A: Allocator + Clone; impl<'id, A> OpenExpression<'id, A> where A: Allocator + Clone, { /// Adds a new scalar value to the expression. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// scalar value. pub fn insert_scalar( &mut self, value: impl Into, ) -> Result, Error> { Ok(( self.0.insert_node(NodeInternal::Scalar(value.into()))?, self.1, ) .into()) } /// Adds a new variable to the expression. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// variable. pub fn insert_variable( &mut self, name: impl AsRef, ) -> Result, Error> { let name = self.0.insert_string(name)?; Ok( (self.0.insert_node(NodeInternal::Variable(name))?, self.1) .into(), ) } /// Adds a new operation to the expression. /// /// # Errors /// /// Returns an error if storage was unable to be reserved for the new /// operation or if one of the children of the node would have been a /// [`Node::Join`]. pub fn insert_operation( &mut self, operator: impl AsRef, arguments: Arguments, ) -> Result, Error> where Arguments: IntoIterator>, Arguments::IntoIter: TrustedLen, { let operator = self.0.insert_string(operator)?; let mut arguments = arguments.into_iter(); let n_arguments = arguments.size_hint().0; if n_arguments == 2 { //SAFETY: the iterator is required to implement `TrustedLen` let first = unsafe { arguments.next().unwrap_unchecked() }; //SAFETY: the iterator is required to implement `TrustedLen` let second = unsafe { arguments.next().unwrap_unchecked() }; if matches!( //SAFETY: lifetime branding ensures this index is valid unsafe { self.0.inner.get_unchecked(first.0) }, NodeInternal::Join(_, _) ) || matches!( //SAFETY: lifetime branding ensures this index is valid unsafe { self.0.inner.get_unchecked(second.0) }, NodeInternal::Join(_, _) ) { return Err(Error::InvalidChild); } Ok(( self.0.insert_node(NodeInternal::BinaryOperation( operator, first.into(), second.into(), ))?, self.1, ) .into()) } else { let number_of_joins = n_arguments.saturating_sub(1); self.0.inner.try_reserve(number_of_joins + 1)?; //SAFETY: the iterator is required to implement `TrustedLen` so // we will have reserved enough space for these in addition // to guaranteeing the noted elements exist // // additionally, because of our lifetime branding, all // indices referenced by lifetime-branded node indices are // valid unsafe { self.0 .inner .push_within_capacity(match n_arguments { 0 => NodeInternal::Operation(operator, None), 1 => { let argument = arguments.next().unwrap_unchecked(); if matches!( self.0.inner.get_unchecked(argument.0), NodeInternal::Join(_, _) ) { return Err(Error::InvalidChild); } NodeInternal::Operation( operator, Some(argument.into()), ) } //NOTE: we know that we will be adding at least one // more slot to hold an `NodeInternal::Join` _ => NodeInternal::Operation( operator, Some(NodeIndexInternal(self.0.inner.len() + 1)), ), }) .unwrap_unchecked() } let operation_index = NodeIndexInternal(self.0.inner.len() - 1); //NOTE: we are exclusively inserting `NodeInternal::Join`s now for (i, argument) in arguments.enumerate() { if matches!( //SAFETY: lifetime branding ensures this index is valid unsafe { self.0.inner.get_unchecked(argument.0) }, NodeInternal::Join(_, _) ) { //NOTE: leaving the inner vector in this state is not // invalid as nothing will be able to reference any // of the slots allocated already // // when it is added, garbage collection will clean up // after this if it happens return Err(Error::InvalidChild); } if i != n_arguments - 1 { //SAFETY: we know that we have reserved enough space for // each join by this point unsafe { self.0 .inner .push_within_capacity(NodeInternal::Join( argument.into(), NodeIndexInternal(self.0.inner.len() + 1), )) .unwrap_unchecked() } } else { let last_index = self.0.inner.len() - 1; if let &mut NodeInternal::Join(_, ref mut second) = //SAFETY: we know that if we are at this point, // the current last element exists unsafe { self.0.inner.get_unchecked_mut(last_index) } { *second = argument.into(); } else { //SAFETY: we know by this point that this last // element must be `NodeInternal::Join` unsafe { hint::unreachable_unchecked() } } } } Ok((operation_index, self.1).into()) } } /// Sets the root node of the expression. pub fn set_root_node(&mut self, node: NodeIndex<'id>) { self.0.root = Some(node.into()); } /// Closes the expression to additive operations and being indexed. pub fn close(self) -> Expression { self.0 } /// Returns the number of nodes within this expression. pub fn len(&self) -> usize { self.0.inner.len() } /// Indicates if the expression has no nodes. pub fn is_empty(&self) -> bool { self.0.inner.is_empty() } /// Gets the root node of the expression, if it exists. pub fn root_node(&self) -> Option> { Some((self.0.root?, self.1).into()) } /// Takes the expression's root node. pub fn take_root_node(&mut self) -> Option> { Some((self.0.root.take()?, self.1).into()) } /// Retrieves the indicated node from the expression. pub fn node(&self, index: NodeIndex<'id>) -> Node<'id> { //SAFETY: by the lifetime, this is guaranteed to be a valid index ( unsafe { self.0.inner.get_unchecked(index.0) }.clone(), self.1, ) .into() } /// Retrieves the indicated string from the expression. pub fn string(&self, index: StringIndex<'id>) -> &str { //SAFETY: by the lifetime, this is guaranteed to be a valid index unsafe { self.0.interns.get_unchecked(index.0) } } /// Returns an iterator over the children of the node. pub fn children_of<'a>( &'a self, index: NodeIndex<'id>, ) -> Children<'a, 'id> { Children::new(&self.0.inner, index) } } #[cfg(feature = "core-fmt")] impl fmt::Display for OpenExpression<'_, A> where A: Allocator + Clone, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } /// Iterator over the children of a [`Node`]. pub struct Children<'a, 'id> { /// The node we are currently at. /// /// If the node given to the constructor was a [`Node::Scalar`], /// [`Node::Variable`], or a [`Node::Operation`] with `None` as the /// second parameter, this must be `None`, and hence this may never be /// referencing a node of those kinds. current: Option, /// Whether or not we are at the "second" part of the node. /// /// This is only relevant for [`Node::BinaryOperation`] as it is not /// clear which of the children are being referenced. second: bool, /// Reference to the internal storage of the [`Expression`]. inner: &'a [NodeInternal], /// Lifetime tying it to the open state of the [`Expression`]. id: Id<'id>, } impl<'a, 'id> Children<'a, 'id> { /// Creates a new iterator over the children of a [`Node`]. fn new(inner: &'a [NodeInternal], index: NodeIndex<'id>) -> Self { let id = index.1; if matches!( //SAFETY: by the lifetime, this is guaranteed to be a valid // index unsafe { inner.get_unchecked(index.0) }, NodeInternal::Scalar(_) | NodeInternal::Variable(_) | NodeInternal::Operation(_, None) ) { Self { current: None, second: false, inner, id, } } else { Self { current: Some(index.into()), second: false, inner, id, } } } } impl<'a, 'id> Iterator for Children<'a, 'id> { type Item = NodeIndex<'id>; fn next(&mut self) -> Option { if let Some(current) = self.current { //SAFETY: by the lifetime, this is guaranteed to be a valid index let current = unsafe { self.inner.get_unchecked(current.0) }; Some( ( match *current { NodeInternal::Operation(_, Some(index)) => { //SAFETY: all internal indices are guaranteed to // be valid let next = unsafe { self.inner.get_unchecked(index.0) }; if let NodeInternal::Join(child_index, _) = *next { self.current = Some(index); child_index } else { self.current = None; index } } NodeInternal::Join(_, index) => { //SAFETY: all internal indices are guaranteed to // be valid let next = unsafe { self.inner.get_unchecked(index.0) }; if let NodeInternal::Join(child_index, _) = *next { self.current = Some(index); child_index } else { self.current = None; index } } NodeInternal::BinaryOperation(_, index, _) if !self.second => { self.second = true; index } NodeInternal::BinaryOperation(_, _, index) => { self.current = None; index } _ => unreachable!(), }, self.id, ) .into(), ) } else { None } } } /// Index into the node vector of [`Expression`]. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 #[allow(repr_transparent_external_private_fields)] pub struct NodeIndex<'id>(usize, Id<'id>); impl<'id> From<(NodeIndexInternal, Id<'id>)> for NodeIndex<'id> { fn from((index, id): (NodeIndexInternal, Id<'id>)) -> Self { Self(index.0, id) } } /// Index into the node vector of [`Expression`]. /// /// This is only for internal use to avoid the presence of explicit lifetimes /// on [`Expression`] members. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub(crate) struct NodeIndexInternal(usize); impl From> for NodeIndexInternal { fn from(NodeIndex(index, _): NodeIndex<'_>) -> Self { Self(index) } } /// Index into the string intern vector of [`Expression`]. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] //NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 #[allow(repr_transparent_external_private_fields)] pub struct StringIndex<'id>(usize, Id<'id>); impl<'id> From<(StringIndexInternal, Id<'id>)> for StringIndex<'id> { fn from((index, id): (StringIndexInternal, Id<'id>)) -> Self { Self(index.0, id) } } /// Index into the string intern vector of [`Expression`]. /// /// This is only for internal use to avoid the presence of explicit lifetimes /// on [`Expression`] members. #[repr(transparent)] #[derive(Eq, PartialEq, Copy, Clone, Hash)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub(crate) struct StringIndexInternal(usize); impl From> for StringIndexInternal { fn from(StringIndex(index, _): StringIndex<'_>) -> Self { Self(index) } } /// Internal representation of an expression node. #[non_exhaustive] #[derive(PartialEq, Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub(crate) enum NodeInternal { /// Operation over other nodes. /// /// The first component indicates the kind of operation and the second is /// the first argument. Several arguments may be represented through the /// use of `Join`. Operation(StringIndexInternal, Option), /// Binary operation. /// /// This is special cased due to how common they are. `Join`s are not /// permitted to be referenced by the indices here. BinaryOperation( StringIndexInternal, NodeIndexInternal, NodeIndexInternal, ), /// Joins one `NodeInternal` with another in a sequence, with the first /// coming before the second. /// /// The first `NodeIndexInternal` must not be a `Join`. Join(NodeIndexInternal, NodeIndexInternal), /// Scalar value. Scalar(ScalarValue), /// Variable. Variable(StringIndexInternal), } /// Node in an expression. #[non_exhaustive] #[derive(PartialEq, Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub enum Node<'id> { /// Operation over other nodes. /// /// The first component indicates the kind of operation and the second is /// the first argument. Several arguments may be represented through the /// use of `Join`. Operation(StringIndex<'id>, Option>), /// Binary operation. /// /// This is special cased due to how common they are. `Join`s are not /// permitted to be referenced by the indices here. BinaryOperation(StringIndex<'id>, NodeIndex<'id>, NodeIndex<'id>), /// Joins one `NodeInternal` with another in a sequence, with the first /// coming before the second. /// /// The first `NodeIndexInternal` must not be a `Join`. Join(NodeIndex<'id>, NodeIndex<'id>), /// Scalar value. Scalar(ScalarValue), /// Variable. Variable(StringIndex<'id>), } impl<'id> From<(NodeInternal, Id<'id>)> for Node<'id> { fn from((node, id): (NodeInternal, Id<'id>)) -> Self { match node { NodeInternal::Operation(s, n) => { Node::Operation((s, id).into(), n.map(|n| (n, id).into())) } NodeInternal::BinaryOperation(s, n1, n2) => { Node::BinaryOperation( (s, id).into(), (n1, id).into(), (n2, id).into(), ) } NodeInternal::Join(n1, n2) => { Node::Join((n1, id).into(), (n2, id).into()) } NodeInternal::Scalar(s) => Node::Scalar(s), NodeInternal::Variable(s) => Node::Variable((s, id).into()), } } } /// Scalar value. #[non_exhaustive] #[derive(PartialEq, Clone)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub enum ScalarValue { /// Unsigned integer. UnsignedInteger(u64), /// Signed integer. Integer(i64), /// Floating point value. Float(f64), /// Rational number with unsigned integer components. UnsignedIntegerRational(Rational), /// Rational number with signed integer components. SignedIntegerRational(Rational), } #[cfg(feature = "core-fmt")] impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::UnsignedInteger(v) => write!(f, "{}", v), Self::Integer(v) => write!(f, "{}", v), Self::Float(v) => write!(f, "{}", v), Self::UnsignedIntegerRational(v) => write!(f, "{}", v), Self::SignedIntegerRational(v) => write!(f, "{}", v), } } } /// Representation of an error that occurred within [`Expression`]. #[non_exhaustive] #[derive(Eq, PartialEq)] #[cfg_attr(feature = "core-fmt", derive(Debug))] pub enum Error { /// Memory reservation error. TryReserveError(TryReserveError), /// Invalid child node. /// /// This occurs when you attempt to pass an index that points to a /// [`Node::Join`] as a child of an operation. InvalidChild, } impl From for Error { fn from(e: TryReserveError) -> Self { Self::TryReserveError(e) } } #[cfg(feature = "core-fmt")] impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let reason = match self { Error::TryReserveError(_) => "unable to allocate memory", Error::InvalidChild => "invalid child node", }; f.write_str(reason) } } #[cfg(feature = "core-error")] impl error::Error for Error {} #[cfg(all(test, feature = "core-error"))] mod test { #![allow(clippy::expect_used)] use super::*; use generativity::make_guard; use proptest::prelude::*; use proptest_state_machine::{ ReferenceStateMachine, StateMachineTest, prop_state_machine, }; use std::{assert_matches::assert_matches, collections::HashMap}; #[test] fn simple() { make_guard!(g); let mut expr = Expression::new().open(g); let a = expr .insert_variable("apples") .expect("unable to insert a variable"); assert_eq!(expr.children_of(a).collect::>().as_slice(), &[]); let b = expr .insert_variable("bananas") .expect("unable to insert a variable"); let c = expr .insert_variable("oranges") .expect("unable to insert a variable"); let sin = expr .insert_operation("sin", [a]) .expect("unable to insert an operation"); assert_eq!( expr.children_of(sin).collect::>().as_slice(), &[a] ); let sum = expr .insert_operation("sum", [a, b, c]) .expect("unable to insert an operation"); assert_eq!( expr.children_of(sum).collect::>().as_slice(), &[a, b, c] ); let product = expr .insert_operation("*", [sin, sum]) .expect("unable to insert an operation"); assert_eq!( expr.children_of(product).collect::>().as_slice(), &[sin, sum] ); expr.set_root_node(product); let expr = expr.close(); println!("{}", expr); } /// Reference state machine for [`Expression`]s. /// /// Models [`Expression`]s as a root node index, a vector of nodes, a /// mapping of parent node indices to a vector of child node indices, /// a vector of valid indices, and whether or not it is currently open. struct ExpressionReference; /// State used by the reference state machine. #[derive(Clone, Debug, Default)] struct ExpressionReferenceState { /// Index of the root node in the node vector. pub root: Option, /// Vector of nodes in the expression. pub nodes: Vec, /// Mapping of parent node indices to child node indices. pub children: HashMap>, /// The current list of valid indices. pub valid_indices: Vec, /// Whether or not the expression is open. /// /// If it is open, then additive operations and indexing is allowed. /// Otherwise, operations that may invalidate indices such as node /// removal, sorting, etc are allowed. pub is_open: bool, } /// Simplified [`NodeInternal`] representation for the [`HashMap`] backed /// graph. #[derive(Clone, Debug)] enum NodeRep { /// Operation over other nodes. Operation(String), /// Scalar value. Scalar(ScalarValue), /// Variable. Variable(String), } /// Possible transitions for the state machine. #[derive(Clone, Debug)] enum ExpressionTransition { /// Open the expression. Open, /// Add a new operation to the expression. NewOperation(String, Vec), /// Add a new scalar to the expression. NewScalar(ScalarValue), /// Add a new variable to the expression. NewVariable(String), /// Set the root node of the expression. SetRoot(usize), /// Remove the root of the expression. TakeRoot, /// Acquire a valid index to the root of expression. AcquireRoot, /// Acquire valid indices to the children of the given node. AcquireChildren(usize), /// Close the expression, invalidating all currently held indices. Close, /// Remove all nodes from the expression. Clear, } impl ExpressionTransition { /// Indicates if this transition requires that the expression be open. fn needs_open(&self) -> bool { matches!( self, ExpressionTransition::NewOperation(_, _) | ExpressionTransition::NewScalar(_) | ExpressionTransition::NewVariable(_) | ExpressionTransition::SetRoot(_) | ExpressionTransition::TakeRoot | ExpressionTransition::AcquireRoot | ExpressionTransition::AcquireChildren(_) | ExpressionTransition::Close ) } } impl ReferenceStateMachine for ExpressionReference { type State = ExpressionReferenceState; type Transition = ExpressionTransition; fn init_state() -> BoxedStrategy { //TODO: randomly initialize state Just(Self::State::default()).boxed() } fn transitions( state: &Self::State, ) -> BoxedStrategy { if state.is_open && state.valid_indices.is_empty() { prop_oneof![ 1 => Just(ExpressionTransition::Close), 1 => Just(ExpressionTransition::TakeRoot), 5 => Just(ExpressionTransition::AcquireRoot), //NOTE: what a scalar value actually is has zero // bearing on anything so we can just yield a static // value here without any problems. in the future we // may have something like a common subexpression // elimination feature, which would make varying it // valuable, but for now, we have no reason to do // anything other than adding a static scalar. this // is different for variables because string // interning exists, so varying those does actually // have an effect 7 => Just(ExpressionTransition::NewScalar( ScalarValue::UnsignedInteger(1) )), // 702 possible identifiers is probably enough 7 => "[a-z]{1,2}" .prop_map(ExpressionTransition::NewVariable), ] .boxed() } else if state.is_open { let n_valid = state.valid_indices.len(); prop_oneof![ 1 => Just(ExpressionTransition::Close), 1 => Just(ExpressionTransition::TakeRoot), 5 => Just(ExpressionTransition::AcquireRoot), //NOTE: what a scalar value actually is has zero // bearing on anything so we can just yield a static // value here without any problems. in the future we // may have something like a common subexpression // elimination feature, which would make varying it // valuable, but for now, we have no reason to do // anything other than adding a static scalar. this // is different for variables because string // interning exists, so varying those does actually // have an effect 7 => Just(ExpressionTransition::NewScalar( ScalarValue::UnsignedInteger(1) )), // 702 possible identifiers is probably enough 7 => "[a-z]{1,2}" .prop_map(ExpressionTransition::NewVariable), 9 => prop::sample::subsequence( state.valid_indices.clone(), 0..=n_valid ) //NOTE: 702 possible identifiers is // probably enough .prop_flat_map(|v| ("[A-Z]{1,2}", Just(v))) .prop_map(|(s, v)| ExpressionTransition::NewOperation(s, v) ), 1 => prop::sample::select(state.valid_indices.clone()) .prop_map(ExpressionTransition::SetRoot), 5 => prop::sample::select(state.valid_indices.clone()) .prop_map(ExpressionTransition::AcquireChildren), ] .boxed() } else { prop_oneof![ 1 => Just(ExpressionTransition::Clear), 9 => Just(ExpressionTransition::Open), ] .boxed() } } fn preconditions( state: &Self::State, transition: &Self::Transition, ) -> bool { if state.is_open ^ transition.needs_open() { return false; } match transition { ExpressionTransition::NewOperation(_, children) => { for c in children { if *c >= state.nodes.len() || !state.valid_indices.contains(c) { return false; } } } ExpressionTransition::SetRoot(node) | ExpressionTransition::AcquireChildren(node) => { if *node >= state.nodes.len() || !state.valid_indices.contains(node) { return false; } } _ => (), } true } fn apply( mut state: Self::State, transition: &Self::Transition, ) -> Self::State { match transition { ExpressionTransition::Open => state.is_open = true, ExpressionTransition::Clear => { state.nodes.clear(); state.children.clear(); state.root = None; } ExpressionTransition::Close => { state.valid_indices.clear(); state.is_open = false; } ExpressionTransition::SetRoot(index) => { state.root = Some(*index); } ExpressionTransition::TakeRoot => { state.root = None; } ExpressionTransition::NewVariable(s) => { state.nodes.push(NodeRep::Variable(s.clone())); let i = state.nodes.len() - 1; state.valid_indices.push(i); state.children.insert(i, Vec::new()); } ExpressionTransition::NewScalar(v) => { state.nodes.push(NodeRep::Scalar(v.clone())); let i = state.nodes.len() - 1; state.valid_indices.push(i); state.children.insert(i, Vec::new()); } ExpressionTransition::AcquireRoot => { if let Some(root) = state.root { state.valid_indices.push(root); } } ExpressionTransition::AcquireChildren(i) => { state.valid_indices.extend(&state.children[i]); } ExpressionTransition::NewOperation(s, c) => { state.nodes.push(NodeRep::Operation(s.clone())); let i = state.nodes.len() - 1; state.children.insert(i, c.clone()); state.valid_indices.push(i); } } state } } /// Wrapper around an [`Expression`] indicating if it is open or not. enum ExpressionWrapper<'id, A = Global> where A: Allocator + Clone, { /// Closed expression.. Closed(Expression), /// Open expression and a mapping from the state machine's indices to /// the [`Expression`]'s. Open(HashMap>, OpenExpression<'id, A>), } impl<'id> StateMachineTest for ExpressionWrapper<'id> { type SystemUnderTest = Self; type Reference = ExpressionReference; fn init_test( _ref_state: &::State, ) -> Self::SystemUnderTest { //TODO: randomly initialize state in the reference and replicate // it over here. ExpressionWrapper::Closed(Expression::new()) } fn apply( mut state: Self::SystemUnderTest, ref_state: &::State, transition: ExpressionTransition, ) -> Self::SystemUnderTest { match transition { ExpressionTransition::Open => { if let ExpressionWrapper::Closed(closed_state) = state { //SAFETY: unfortunately, we have to do this, // otherwise this test cannot work out due to // lifetime issues let guard = unsafe { Guard::new(Id::new()) }; state = ExpressionWrapper::Open( HashMap::new(), closed_state.open(guard), ); } else { unreachable!(); } } ExpressionTransition::Clear => { if let ExpressionWrapper::Closed(closed_state) = &mut state { closed_state.clear(); //NOTE: post-conditions assert!( closed_state.inner.is_empty(), "inner node vector should be empty after clearing" ); assert!( closed_state.interns.is_empty(), "intern vector should be empty after clearing" ); assert!( closed_state.root.is_none(), "root node should not exist after clearing" ); } else { unreachable!(); } } ExpressionTransition::Close => { if let ExpressionWrapper::Open(_, open_state) = state { state = ExpressionWrapper::Closed(open_state.close()); } else { unreachable!(); } } ExpressionTransition::SetRoot(index) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { open_state.set_root_node(mapping[&index]); //NOTE: post-conditions assert_eq!( open_state.0.root, Some(NodeIndexInternal(mapping[&index].0)), "root should be set to the given index after setting it" ); } else { unreachable!(); } } ExpressionTransition::TakeRoot => { if let ExpressionWrapper::Open(_, open_state) = &mut state { let root_before = open_state.0.root; let taken_root = open_state .take_root_node() .map(|r| NodeIndexInternal(r.0)); //NOTE: post-conditions assert!( open_state.0.root.is_none(), "root should be None after taking it" ); assert_eq!( root_before, taken_root, "root before taking and returned root value from taking should be equal" ); } else { unreachable!(); } } ExpressionTransition::NewVariable(s) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let var = open_state .insert_variable(&s) .expect("inserting a variable should not fail"); mapping.insert(ref_state.nodes.len() - 1, var); let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted variable"); assert_matches!( open_state.0.inner[var.0], NodeInternal::Variable(position), "the variable node should point to the position of the equivalent string in the intern list" ); } else { unreachable!(); } } ExpressionTransition::NewScalar(v) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let scalar = open_state .insert_scalar(v) .expect("inserting a scalar should not fail"); mapping.insert(ref_state.nodes.len() - 1, scalar); assert_matches!( &open_state.0.inner[scalar.0], NodeInternal::Scalar(v), "the scalar node should contain the same value as was inserted" ); } else { unreachable!(); } } ExpressionTransition::NewOperation(s, c) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let operation = open_state .insert_operation( &s, c.iter().copied().map(|i| mapping[&i]), ) .expect("inserting an operation should not fail"); mapping.insert(ref_state.nodes.len() - 1, operation); let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted operation"); if c.len() == 2 { assert_matches!( open_state.0.inner[operation.0], NodeInternal::BinaryOperation(position, _, _), "the operation node should point to the position of the equivalent string in the intern list", ); } else { assert_matches!( open_state.0.inner[operation.0], NodeInternal::Operation(position, _), "the operation node should point to the position of the equivalent string in the intern list", ); } //TODO: validate join nodes, child node equality } else { unreachable!(); } } ExpressionTransition::AcquireRoot => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { let root = open_state.root_node(); assert_eq!( ref_state.root.is_some(), root.is_some(), "both the reference state and system under test must match" ); if let Some(root) = root && let Some(ref_root) = ref_state.root { mapping.insert(ref_root, root); } } else { unreachable!(); } } ExpressionTransition::AcquireChildren(i) => { if let ExpressionWrapper::Open(mapping, open_state) = &mut state { mapping.extend( ref_state.children[&i] .iter() .copied() .zip(open_state.children_of(mapping[&i])), ); } else { unreachable!(); } } } state } fn check_invariants( state: &Self::SystemUnderTest, ref_state: &::State, ) { let expr = match state { ExpressionWrapper::Open(_, expr) => &expr.0, ExpressionWrapper::Closed(expr) => expr, }; let mut deduplicated_interns = expr.interns.clone(); deduplicated_interns.sort(); deduplicated_interns.dedup(); assert!( deduplicated_interns.len() == expr.interns.len(), "there must be no duplicated string interns" ); for node in &expr.inner { match node { NodeInternal::Operation(i, Some(n)) => { assert!( i.0 < expr.interns.len(), "intern index must be valid" ); assert!( n.0 < expr.inner.len(), "node index must be valid" ); } NodeInternal::BinaryOperation(i, n1, n2) => { assert!( i.0 < expr.interns.len(), "intern index must be valid" ); assert!( n1.0 < expr.inner.len(), "node index must be valid" ); assert!( n2.0 < expr.inner.len(), "node index must be valid" ); assert!( !matches!( expr.inner[n1.0], NodeInternal::Join(_, _) ), "binary operations must not reference joins" ); assert!( !matches!( expr.inner[n2.0], NodeInternal::Join(_, _) ), "binary operations must not reference joins" ); } NodeInternal::Join(n1, n2) => { assert!( n1.0 < expr.inner.len(), "node index must be valid" ); assert!( !matches!( expr.inner[n1.0], NodeInternal::Join(_, _) ), "operations must not have joins as children" ); assert!( n2.0 < expr.inner.len(), "node index must be valid" ); } NodeInternal::Variable(i) => { assert!( i.0 < expr.interns.len(), "intern index must be valid" ); } _ => (), } } //TODO: the graph must never contain a cycle. //TODO: validate the layout of the entire expression against the // reference. } } prop_state_machine! { #![proptest_config(ProptestConfig { // allow more rejects to make shrinking results better max_global_rejects: u32::MAX, // allow more shrinking iterations to improve the results max_shrink_iters: u32::MAX, // disable failure persistence so miri works #[cfg(miri)] failure_persistence: None, ..ProptestConfig::default() })] #[test] fn matches_state_machine(sequential 1..1_000 => ExpressionWrapper); } }