diff options
Diffstat (limited to 'crates/core/src/expressions/mod.rs')
-rw-r--r-- | crates/core/src/expressions/mod.rs | 1351 |
1 files changed, 1351 insertions, 0 deletions
diff --git a/crates/core/src/expressions/mod.rs b/crates/core/src/expressions/mod.rs new file mode 100644 index 0000000..c5abebf --- /dev/null +++ b/crates/core/src/expressions/mod.rs @@ -0,0 +1,1351 @@ +//! Common representation of mathematical expressions. + +use alloc::{ + alloc::{Allocator, Global}, + collections::TryReserveError, + string::String, + vec::Vec, +}; +use core::{hint, iter::TrustedLen}; +use generativity::{Guard, Id}; + +use crate::numerics::rational::Rational; + +#[cfg(feature = "core-fmt")] +use core::fmt; + +#[cfg(feature = "core-error")] +use core::error; + +//TODO: consider topological sorting of the expression's internals as this +// may be better for performance. +//TODO: consider garbage collection of unreachable nodes (and their +// children). +//TODO: if we add garbage collection, support marking a node for deletion +// when garbage collection is performed. +//TODO: consider renaming `OpenExpression` to something else. +//TODO: add more thorough documentation, along the lines of the standard +// library's `Vec`, given how this is a very fundamental type. + +/// Representation of a mathematical expression. +/// +/// Components of the expression are maintained in a vector where they are +/// linked to each other by their indices. Indices always point to another +/// existing node. +/// +/// # Guarantees +/// +/// - The expression contains no cycles. +/// - All references to other nodes are valid. +/// - All references to interned strings are valid. +/// - No strings are interned more than once. +#[derive(Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub struct Expression<A = Global> +where + A: Allocator + Clone, +{ + /// List of nodes linked to each other by their indices. + pub(crate) inner: Vec<NodeInternal, A>, + + /// List of interned strings. + pub(crate) interns: Vec<String, A>, + + /// Root node of the expression. + pub(crate) root: Option<NodeIndexInternal>, +} + +impl Expression<Global> { + /// Constructs an empty expression. + pub fn new() -> Self { + Self::new_in(Global) + } +} + +impl<A> Expression<A> +where + A: Allocator + Clone, +{ + /// Constructs an empty expression using the provided [`Allocator`]. + pub fn new_in(alloc: A) -> Self { + Self { + inner: Vec::new_in(alloc.clone()), + interns: Vec::new_in(alloc), + root: None, + } + } + + /// Clears the expression, removing all nodes and associated information. + pub fn clear(&mut self) { + self.inner.clear(); + self.interns.clear(); + self.root = None; + } + + /// Opens up the expression for additive operations and indexing. + /// + /// While in this state, no destructive operations may be done, such as + /// garbage collection and sorting of the node list. + pub fn open<'id>(self, guard: Guard<'id>) -> OpenExpression<'id, A> { + OpenExpression(self, guard.into()) + } + + /// Inserts a [`NodeInternal`] into the node list and returns the index. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// node. + fn insert_node( + &mut self, + node: NodeInternal, + ) -> Result<NodeIndexInternal, Error> { + self.inner.try_reserve(1)?; + //SAFETY: we just reserved space for this element + unsafe { self.inner.push_within_capacity(node).unwrap_unchecked() } + Ok(NodeIndexInternal(self.inner.len() - 1)) + } + + /// Inserts a [`&str`] into the intern list if it doesn't already exist + /// and returns the index. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// string. + fn insert_string( + &mut self, + string: impl AsRef<str>, + ) -> Result<StringIndexInternal, Error> { + if let Some(i) = + self.interns.iter().position(|s| s == string.as_ref()) + { + return Ok(StringIndexInternal(i)); + } + + self.interns.try_reserve(1)?; + //SAFETY: we just reserved space for this element + unsafe { + self.interns + .push_within_capacity(string.as_ref().into()) + .unwrap_unchecked() + } + Ok(StringIndexInternal(self.interns.len() - 1)) + } + + /// Returns the number of nodes within this expression. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Indicates if the expression has no nodes. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +impl Default for Expression<Global> { + fn default() -> Self { + Self::new() + } +} + +/// Representation of a mathematical expression. +/// +/// This is a variant of [`Expression`] in a state where it is able to be +/// written to and indexed. +#[derive(Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub struct OpenExpression<'id, A>(Expression<A>, Id<'id>) +where + A: Allocator + Clone; + +impl<'id, A> OpenExpression<'id, A> +where + A: Allocator + Clone, +{ + /// Adds a new scalar value to the expression. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// scalar value. + pub fn insert_scalar( + &mut self, + value: impl Into<ScalarValue>, + ) -> Result<NodeIndex<'id>, Error> { + Ok(( + self.0.insert_node(NodeInternal::Scalar(value.into()))?, + self.1, + ) + .into()) + } + + /// Adds a new variable to the expression. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// variable. + pub fn insert_variable( + &mut self, + name: impl AsRef<str>, + ) -> Result<NodeIndex<'id>, Error> { + let name = self.0.insert_string(name)?; + Ok( + (self.0.insert_node(NodeInternal::Variable(name))?, self.1) + .into(), + ) + } + + /// Adds a new operation to the expression. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// operation or if one of the children of the node would have been a + /// [`Node::Join`]. + pub fn insert_operation<Arguments>( + &mut self, + operator: impl AsRef<str>, + arguments: Arguments, + ) -> Result<NodeIndex<'id>, Error> + where + Arguments: IntoIterator<Item = NodeIndex<'id>>, + Arguments::IntoIter: TrustedLen, + { + let operator = self.0.insert_string(operator)?; + let mut arguments = arguments.into_iter(); + let n_arguments = arguments.size_hint().0; + if n_arguments == 2 { + //SAFETY: the iterator is required to implement `TrustedLen` + let first = unsafe { arguments.next().unwrap_unchecked() }; + + //SAFETY: the iterator is required to implement `TrustedLen` + let second = unsafe { arguments.next().unwrap_unchecked() }; + + if matches!( + //SAFETY: lifetime branding ensures this index is valid + unsafe { self.0.inner.get_unchecked(first.0) }, + NodeInternal::Join(_, _) + ) || matches!( + //SAFETY: lifetime branding ensures this index is valid + unsafe { self.0.inner.get_unchecked(second.0) }, + NodeInternal::Join(_, _) + ) { + return Err(Error::InvalidChild); + } + + Ok(( + self.0.insert_node(NodeInternal::BinaryOperation( + operator, + first.into(), + second.into(), + ))?, + self.1, + ) + .into()) + } else { + let number_of_joins = n_arguments.saturating_sub(1); + self.0.inner.try_reserve(number_of_joins + 1)?; + + //SAFETY: the iterator is required to implement `TrustedLen` so + // we will have reserved enough space for these in addition + // to guaranteeing the noted elements exist + // + // additionally, because of our lifetime branding, all + // indices referenced by lifetime-branded node indices are + // valid + unsafe { + self.0 + .inner + .push_within_capacity(match n_arguments { + 0 => NodeInternal::Operation(operator, None), + 1 => { + let argument = + arguments.next().unwrap_unchecked(); + if matches!( + self.0.inner.get_unchecked(argument.0), + NodeInternal::Join(_, _) + ) { + return Err(Error::InvalidChild); + } + + NodeInternal::Operation( + operator, + Some(argument.into()), + ) + } + //NOTE: we know that we will be adding at least one + // more slot to hold an `NodeInternal::Join` + _ => NodeInternal::Operation( + operator, + Some(NodeIndexInternal(self.0.inner.len() + 1)), + ), + }) + .unwrap_unchecked() + } + let operation_index = NodeIndexInternal(self.0.inner.len() - 1); + + //NOTE: we are exclusively inserting `NodeInternal::Join`s now + for (i, argument) in arguments.enumerate() { + if matches!( + //SAFETY: lifetime branding ensures this index is valid + unsafe { self.0.inner.get_unchecked(argument.0) }, + NodeInternal::Join(_, _) + ) { + //NOTE: leaving the inner vector in this state is not + // invalid as nothing will be able to reference any + // of the slots allocated already + // + // when it is added, garbage collection will clean up + // after this if it happens + return Err(Error::InvalidChild); + } + + if i != n_arguments - 1 { + //SAFETY: we know that we have reserved enough space for + // each join by this point + unsafe { + self.0 + .inner + .push_within_capacity(NodeInternal::Join( + argument.into(), + NodeIndexInternal(self.0.inner.len() + 1), + )) + .unwrap_unchecked() + } + } else { + let last_index = self.0.inner.len() - 1; + + if let &mut NodeInternal::Join(_, ref mut second) = + //SAFETY: we know that if we are at this point, + // the current last element exists + unsafe { + self.0.inner.get_unchecked_mut(last_index) + } + { + *second = argument.into(); + } else { + //SAFETY: we know by this point that this last + // element must be `NodeInternal::Join` + unsafe { hint::unreachable_unchecked() } + } + } + } + + Ok((operation_index, self.1).into()) + } + } + + /// Sets the root node of the expression. + pub fn set_root_node(&mut self, node: NodeIndex<'id>) { + self.0.root = Some(node.into()); + } + + /// Closes the expression to additive operations and being indexed. + pub fn close(self) -> Expression<A> { + self.0 + } + + /// Returns the number of nodes within this expression. + pub fn len(&self) -> usize { + self.0.inner.len() + } + + /// Indicates if the expression has no nodes. + pub fn is_empty(&self) -> bool { + self.0.inner.is_empty() + } + + /// Gets the root node of the expression, if it exists. + pub fn root_node(&self) -> Option<NodeIndex<'id>> { + Some((self.0.root?, self.1).into()) + } + + /// Takes the expression's root node. + pub fn take_root_node(&mut self) -> Option<NodeIndex<'id>> { + Some((self.0.root.take()?, self.1).into()) + } + + /// Retrieves the indicated node from the expression. + pub fn node(&self, index: NodeIndex<'id>) -> Node<'id> { + //SAFETY: by the lifetime, this is guaranteed to be a valid index + ( + unsafe { self.0.inner.get_unchecked(index.0) }.clone(), + self.1, + ) + .into() + } + + /// Retrieves the indicated string from the expression. + pub fn string(&self, index: StringIndex<'id>) -> &str { + //SAFETY: by the lifetime, this is guaranteed to be a valid index + unsafe { self.0.interns.get_unchecked(index.0) } + } + + /// Returns an iterator over the children of the node. + pub fn children_of<'a>( + &'a self, + index: NodeIndex<'id>, + ) -> Children<'a, 'id> { + Children::new(&self.0.inner, index) + } +} + +/// Iterator over the children of a [`Node`]. +pub struct Children<'a, 'id> { + /// The node we are currently at. + /// + /// If the node given to the constructor was a [`Node::Scalar`], + /// [`Node::Variable`], or a [`Node::Operation`] with `None` as the + /// second parameter, this must be `None`, and hence this may never be + /// referencing a node of those kinds. + current: Option<NodeIndexInternal>, + + /// Whether or not we are at the "second" part of the node. + /// + /// This is only relevant for [`Node::BinaryOperation`] as it is not + /// clear which of the children are being referenced. + second: bool, + + /// Reference to the internal storage of the [`Expression`]. + inner: &'a [NodeInternal], + + /// Lifetime tying it to the open state of the [`Expression`]. + id: Id<'id>, +} + +impl<'a, 'id> Children<'a, 'id> { + /// Creates a new iterator over the children of a [`Node`]. + fn new(inner: &'a [NodeInternal], index: NodeIndex<'id>) -> Self { + let id = index.1; + if matches!( + //SAFETY: by the lifetime, this is guaranteed to be a valid + // index + unsafe { inner.get_unchecked(index.0) }, + NodeInternal::Scalar(_) + | NodeInternal::Variable(_) + | NodeInternal::Operation(_, None) + ) { + Self { + current: None, + second: false, + inner, + id, + } + } else { + Self { + current: Some(index.into()), + second: false, + inner, + id, + } + } + } +} + +impl<'a, 'id> Iterator for Children<'a, 'id> { + type Item = NodeIndex<'id>; + + fn next(&mut self) -> Option<Self::Item> { + if let Some(current) = self.current { + //SAFETY: by the lifetime, this is guaranteed to be a valid index + let current = unsafe { self.inner.get_unchecked(current.0) }; + Some( + ( + match *current { + NodeInternal::Operation(_, Some(index)) => { + //SAFETY: all internal indices are guaranteed to + // be valid + let next = + unsafe { self.inner.get_unchecked(index.0) }; + + if let NodeInternal::Join(child_index, _) = *next + { + self.current = Some(index); + child_index + } else { + self.current = None; + index + } + } + NodeInternal::Join(_, index) => { + //SAFETY: all internal indices are guaranteed to + // be valid + let next = + unsafe { self.inner.get_unchecked(index.0) }; + + if let NodeInternal::Join(child_index, _) = *next + { + self.current = Some(index); + child_index + } else { + self.current = None; + index + } + } + NodeInternal::BinaryOperation(_, index, _) + if !self.second => + { + self.second = true; + index + } + NodeInternal::BinaryOperation(_, _, index) => { + self.current = None; + index + } + _ => unreachable!(), + }, + self.id, + ) + .into(), + ) + } else { + None + } + } +} + +/// Index into the node vector of [`Expression`]. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +//NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 +#[allow(repr_transparent_external_private_fields)] +pub struct NodeIndex<'id>(usize, Id<'id>); + +impl<'id> From<(NodeIndexInternal, Id<'id>)> for NodeIndex<'id> { + fn from((index, id): (NodeIndexInternal, Id<'id>)) -> Self { + Self(index.0, id) + } +} + +/// Index into the node vector of [`Expression`]. +/// +/// This is only for internal use to avoid the presence of explicit lifetimes +/// on [`Expression`] members. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub(crate) struct NodeIndexInternal(usize); + +impl From<NodeIndex<'_>> for NodeIndexInternal { + fn from(NodeIndex(index, _): NodeIndex<'_>) -> Self { + Self(index) + } +} + +/// Index into the string intern vector of [`Expression`]. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +//NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 +#[allow(repr_transparent_external_private_fields)] +pub struct StringIndex<'id>(usize, Id<'id>); + +impl<'id> From<(StringIndexInternal, Id<'id>)> for StringIndex<'id> { + fn from((index, id): (StringIndexInternal, Id<'id>)) -> Self { + Self(index.0, id) + } +} + +/// Index into the string intern vector of [`Expression`]. +/// +/// This is only for internal use to avoid the presence of explicit lifetimes +/// on [`Expression`] members. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub(crate) struct StringIndexInternal(usize); + +impl From<StringIndex<'_>> for StringIndexInternal { + fn from(StringIndex(index, _): StringIndex<'_>) -> Self { + Self(index) + } +} + +/// Internal representation of an expression node. +#[non_exhaustive] +#[derive(PartialEq, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub(crate) enum NodeInternal { + /// Operation over other nodes. + /// + /// The first component indicates the kind of operation and the second is + /// the first argument. Several arguments may be represented through the + /// use of `Join`. + Operation(StringIndexInternal, Option<NodeIndexInternal>), + + /// Binary operation. + /// + /// This is special cased due to how common they are. `Join`s are not + /// permitted to be referenced by the indices here. + BinaryOperation( + StringIndexInternal, + NodeIndexInternal, + NodeIndexInternal, + ), + + /// Joins one `NodeInternal` with another in a sequence, with the first + /// coming before the second. + /// + /// The first `NodeIndexInternal` must not be a `Join`. + Join(NodeIndexInternal, NodeIndexInternal), + + /// Scalar value. + Scalar(ScalarValue), + + /// Variable. + Variable(StringIndexInternal), +} + +/// Node in an expression. +#[non_exhaustive] +#[derive(PartialEq, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub enum Node<'id> { + /// Operation over other nodes. + /// + /// The first component indicates the kind of operation and the second is + /// the first argument. Several arguments may be represented through the + /// use of `Join`. + Operation(StringIndex<'id>, Option<NodeIndex<'id>>), + + /// Binary operation. + /// + /// This is special cased due to how common they are. `Join`s are not + /// permitted to be referenced by the indices here. + BinaryOperation(StringIndex<'id>, NodeIndex<'id>, NodeIndex<'id>), + + /// Joins one `NodeInternal` with another in a sequence, with the first + /// coming before the second. + /// + /// The first `NodeIndexInternal` must not be a `Join`. + Join(NodeIndex<'id>, NodeIndex<'id>), + + /// Scalar value. + Scalar(ScalarValue), + + /// Variable. + Variable(StringIndex<'id>), +} + +impl<'id> From<(NodeInternal, Id<'id>)> for Node<'id> { + fn from((node, id): (NodeInternal, Id<'id>)) -> Self { + match node { + NodeInternal::Operation(s, n) => { + Node::Operation((s, id).into(), n.map(|n| (n, id).into())) + } + NodeInternal::BinaryOperation(s, n1, n2) => { + Node::BinaryOperation( + (s, id).into(), + (n1, id).into(), + (n2, id).into(), + ) + } + NodeInternal::Join(n1, n2) => { + Node::Join((n1, id).into(), (n2, id).into()) + } + NodeInternal::Scalar(s) => Node::Scalar(s), + NodeInternal::Variable(s) => Node::Variable((s, id).into()), + } + } +} + +/// Scalar value. +#[non_exhaustive] +#[derive(PartialEq, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub enum ScalarValue { + /// Unsigned integer. + UnsignedInteger(u64), + + /// Signed integer. + Integer(i64), + + /// Floating point value. + Float(f64), + + /// Rational number with unsigned integer components. + UnsignedIntegerRational(Rational<u64>), + + /// Rational number with signed integer components. + SignedIntegerRational(Rational<i64>), +} + +/// Representation of an error that occurred within [`Expression`]. +#[non_exhaustive] +#[derive(Eq, PartialEq)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub enum Error { + /// Memory reservation error. + TryReserveError(TryReserveError), + + /// Invalid child node. + /// + /// This occurs when you attempt to pass an index that points to a + /// [`Node::Join`] as a child of an operation. + InvalidChild, +} + +impl From<TryReserveError> for Error { + fn from(e: TryReserveError) -> Self { + Self::TryReserveError(e) + } +} + +#[cfg(feature = "core-fmt")] +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let reason = match self { + Error::TryReserveError(_) => "unable to allocate memory", + Error::InvalidChild => "invalid child node", + }; + fmt.write_str(reason) + } +} + +#[cfg(feature = "core-error")] +impl error::Error for Error {} + +#[cfg(all(test, feature = "core-error"))] +mod test { + #![allow(clippy::expect_used)] + + use super::*; + + use generativity::make_guard; + use proptest::prelude::*; + use proptest_state_machine::{ + ReferenceStateMachine, StateMachineTest, prop_state_machine, + }; + use std::{assert_matches::assert_matches, collections::HashMap}; + + #[test] + fn simple() { + make_guard!(g); + let mut expr = Expression::new().open(g); + + let a = expr + .insert_variable("apples") + .expect("unable to insert a variable"); + assert_eq!(expr.children_of(a).collect::<Vec<_>>().as_slice(), &[]); + + let b = expr + .insert_variable("bananas") + .expect("unable to insert a variable"); + let c = expr + .insert_variable("oranges") + .expect("unable to insert a variable"); + + let sin = expr + .insert_operation("sin", [a]) + .expect("unable to insert an operation"); + assert_eq!( + expr.children_of(sin).collect::<Vec<_>>().as_slice(), + &[a] + ); + + let sum = expr + .insert_operation("sum", [a, b, c]) + .expect("unable to insert an operation"); + assert_eq!( + expr.children_of(sum).collect::<Vec<_>>().as_slice(), + &[a, b, c] + ); + + let product = expr + .insert_operation("*", [sin, sum]) + .expect("unable to insert an operation"); + assert_eq!( + expr.children_of(product).collect::<Vec<_>>().as_slice(), + &[sin, sum] + ); + + expr.close(); + } + + /// Reference state machine for [`Expression`]s. + /// + /// Models [`Expression`]s as a root node index, a vector of nodes, a + /// mapping of parent node indices to a vector of child node indices, + /// a vector of valid indices, and whether or not it is currently open. + struct ExpressionReference; + + /// State used by the reference state machine. + #[derive(Clone, Debug, Default)] + struct ExpressionReferenceState { + /// Index of the root node in the node vector. + pub root: Option<usize>, + + /// Vector of nodes in the expression. + pub nodes: Vec<NodeRep>, + + /// Mapping of parent node indices to child node indices. + pub children: HashMap<usize, Vec<usize>>, + + /// The current list of valid indices. + pub valid_indices: Vec<usize>, + + /// Whether or not the expression is open. + /// + /// If it is open, then additive operations and indexing is allowed. + /// Otherwise, operations that may invalidate indices such as node + /// removal, sorting, etc are allowed. + pub is_open: bool, + } + + /// Simplified [`NodeInternal`] representation for the [`HashMap`] backed + /// graph. + #[derive(Clone, Debug)] + enum NodeRep { + /// Operation over other nodes. + Operation(String), + + /// Scalar value. + Scalar(ScalarValue), + + /// Variable. + Variable(String), + } + + /// Possible transitions for the state machine. + #[derive(Clone, Debug)] + enum ExpressionTransition { + /// Open the expression. + Open, + + /// Add a new operation to the expression. + NewOperation(String, Vec<usize>), + + /// Add a new scalar to the expression. + NewScalar(ScalarValue), + + /// Add a new variable to the expression. + NewVariable(String), + + /// Set the root node of the expression. + SetRoot(usize), + + /// Remove the root of the expression. + TakeRoot, + + /// Acquire a valid index to the root of expression. + AcquireRoot, + + /// Acquire valid indices to the children of the given node. + AcquireChildren(usize), + + /// Close the expression, invalidating all currently held indices. + Close, + + /// Remove all nodes from the expression. + Clear, + } + + impl ExpressionTransition { + /// Indicates if this transition requires that the expression be open. + fn needs_open(&self) -> bool { + matches!( + self, + ExpressionTransition::NewOperation(_, _) + | ExpressionTransition::NewScalar(_) + | ExpressionTransition::NewVariable(_) + | ExpressionTransition::SetRoot(_) + | ExpressionTransition::TakeRoot + | ExpressionTransition::AcquireRoot + | ExpressionTransition::AcquireChildren(_) + | ExpressionTransition::Close + ) + } + } + + impl ReferenceStateMachine for ExpressionReference { + type State = ExpressionReferenceState; + type Transition = ExpressionTransition; + + fn init_state() -> BoxedStrategy<Self::State> { + //TODO: randomly initialize state + Just(Self::State::default()).boxed() + } + + fn transitions( + state: &Self::State, + ) -> BoxedStrategy<Self::Transition> { + if state.is_open && state.valid_indices.is_empty() { + let n_valid = state.valid_indices.len(); + prop_oneof![ + 1 => Just(ExpressionTransition::Close), + 1 => Just(ExpressionTransition::TakeRoot), + 5 => Just(ExpressionTransition::AcquireRoot), + //NOTE: what a scalar value actually is has zero + // bearing on anything so we can just yield a static + // value here without any problems. in the future we + // may have something like a common subexpression + // elimination feature, which would make varying it + // valuable, but for now, we have no reason to do + // anything other than adding a static scalar. this + // is different for variables because string + // interning exists, so varying those does actually + // have an effect + 7 => Just(ExpressionTransition::NewScalar( + ScalarValue::UnsignedInteger(1) + )), + // 702 possible identifiers is probably enough + 7 => "[a-z]{1,2}" + .prop_map(ExpressionTransition::NewVariable), + ] + .boxed() + } else if state.is_open { + let n_valid = state.valid_indices.len(); + prop_oneof![ + 1 => Just(ExpressionTransition::Close), + 1 => Just(ExpressionTransition::TakeRoot), + 5 => Just(ExpressionTransition::AcquireRoot), + //NOTE: what a scalar value actually is has zero + // bearing on anything so we can just yield a static + // value here without any problems. in the future we + // may have something like a common subexpression + // elimination feature, which would make varying it + // valuable, but for now, we have no reason to do + // anything other than adding a static scalar. this + // is different for variables because string + // interning exists, so varying those does actually + // have an effect + 7 => Just(ExpressionTransition::NewScalar( + ScalarValue::UnsignedInteger(1) + )), + // 702 possible identifiers is probably enough + 7 => "[a-z]{1,2}" + .prop_map(ExpressionTransition::NewVariable), + 9 => prop::sample::subsequence( + state.valid_indices.clone(), + 0..=n_valid + ) + //NOTE: 702 possible identifiers is + // probably enough + .prop_flat_map(|v| ("[A-Z]{1,2}", Just(v))) + .prop_map(|(s, v)| + ExpressionTransition::NewOperation(s, v) + ), + 1 => prop::sample::select(state.valid_indices.clone()) + .prop_map(ExpressionTransition::SetRoot), + 5 => prop::sample::select(state.valid_indices.clone()) + .prop_map(ExpressionTransition::AcquireChildren), + ] + .boxed() + } else { + prop_oneof![ + 1 => Just(ExpressionTransition::Clear), + 9 => Just(ExpressionTransition::Open), + ] + .boxed() + } + } + + fn preconditions( + state: &Self::State, + transition: &Self::Transition, + ) -> bool { + if state.is_open ^ transition.needs_open() { + return false; + } + + match transition { + ExpressionTransition::NewOperation(_, children) => { + for c in children { + if *c >= state.nodes.len() + || !state.valid_indices.contains(c) + { + return false; + } + } + } + ExpressionTransition::SetRoot(node) + | ExpressionTransition::AcquireChildren(node) => { + if *node >= state.nodes.len() + || !state.valid_indices.contains(node) + { + return false; + } + } + _ => (), + } + true + } + + fn apply( + mut state: Self::State, + transition: &Self::Transition, + ) -> Self::State { + match transition { + ExpressionTransition::Open => state.is_open = true, + ExpressionTransition::Clear => { + state.nodes.clear(); + state.children.clear(); + state.root = None; + } + ExpressionTransition::Close => { + state.valid_indices.clear(); + state.is_open = false; + } + ExpressionTransition::SetRoot(index) => { + state.root = Some(*index); + } + ExpressionTransition::TakeRoot => { + state.root = None; + } + ExpressionTransition::NewVariable(s) => { + state.nodes.push(NodeRep::Variable(s.clone())); + let i = state.nodes.len() - 1; + state.valid_indices.push(i); + state.children.insert(i, Vec::new()); + } + ExpressionTransition::NewScalar(v) => { + state.nodes.push(NodeRep::Scalar(v.clone())); + let i = state.nodes.len() - 1; + state.valid_indices.push(i); + state.children.insert(i, Vec::new()); + } + ExpressionTransition::AcquireRoot => { + if let Some(root) = state.root { + state.valid_indices.push(root); + } + } + ExpressionTransition::AcquireChildren(i) => { + state.valid_indices.extend(&state.children[i]); + } + ExpressionTransition::NewOperation(s, c) => { + state.nodes.push(NodeRep::Operation(s.clone())); + let i = state.nodes.len() - 1; + state.children.insert(i, c.clone()); + state.valid_indices.push(i); + } + } + state + } + } + + /// Wrapper around an [`Expression`] indicating if it is open or not. + enum ExpressionWrapper<'id, A = Global> + where + A: Allocator + Clone, + { + /// Closed expression.. + Closed(Expression<A>), + + /// Open expression and a mapping from the state machine's indices to + /// the [`Expression`]'s. + Open(HashMap<usize, NodeIndex<'id>>, OpenExpression<'id, A>), + } + + impl<'id> StateMachineTest for ExpressionWrapper<'id> { + type SystemUnderTest = Self; + type Reference = ExpressionReference; + + fn init_test( + _ref_state: &<Self::Reference as ReferenceStateMachine>::State, + ) -> Self::SystemUnderTest { + //TODO: randomly initialize state in the reference and replicate + // it over here. + ExpressionWrapper::Closed(Expression::new()) + } + + fn apply( + mut state: Self::SystemUnderTest, + ref_state: &<Self::Reference as ReferenceStateMachine>::State, + transition: ExpressionTransition, + ) -> Self::SystemUnderTest { + match transition { + ExpressionTransition::Open => { + if let ExpressionWrapper::Closed(closed_state) = state { + //SAFETY: unfortunately, we have to do this, + // otherwise this test cannot work out due to + // lifetime issues + let guard = unsafe { Guard::new(Id::new()) }; + state = ExpressionWrapper::Open( + HashMap::new(), + closed_state.open(guard), + ); + } else { + unreachable!(); + } + } + ExpressionTransition::Clear => { + if let ExpressionWrapper::Closed(closed_state) = + &mut state + { + closed_state.clear(); + + //NOTE: post-conditions + assert!( + closed_state.inner.is_empty(), + "inner node vector should be empty after clearing" + ); + assert!( + closed_state.interns.is_empty(), + "intern vector should be empty after clearing" + ); + assert!( + closed_state.root.is_none(), + "root node should not exist after clearing" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::Close => { + if let ExpressionWrapper::Open(_, open_state) = state { + state = ExpressionWrapper::Closed(open_state.close()); + } else { + unreachable!(); + } + } + ExpressionTransition::SetRoot(index) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + open_state.set_root_node(mapping[&index]); + + //NOTE: post-conditions + assert_eq!( + open_state.0.root, + Some(NodeIndexInternal(mapping[&index].0)), + "root should be set to the given index after setting it" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::TakeRoot => { + if let ExpressionWrapper::Open(_, open_state) = &mut state + { + let root_before = open_state.0.root; + let taken_root = open_state + .take_root_node() + .map(|r| NodeIndexInternal(r.0)); + + //NOTE: post-conditions + assert!( + open_state.0.root.is_none(), + "root should be None after taking it" + ); + assert_eq!( + root_before, taken_root, + "root before taking and returned root value from taking should be equal" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::NewVariable(s) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let var = open_state + .insert_variable(&s) + .expect("inserting a variable should not fail"); + mapping.insert(ref_state.nodes.len() - 1, var); + + let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted variable"); + assert_matches!( + open_state.0.inner[var.0], + NodeInternal::Variable(position), + "the variable node should point to the position of the equivalent string in the intern list" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::NewScalar(v) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let scalar = open_state + .insert_scalar(v) + .expect("inserting a scalar should not fail"); + mapping.insert(ref_state.nodes.len() - 1, scalar); + assert_matches!( + &open_state.0.inner[scalar.0], + NodeInternal::Scalar(v), + "the scalar node should contain the same value as was inserted" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::NewOperation(s, c) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let operation = open_state + .insert_operation( + &s, + c.iter().copied().map(|i| mapping[&i]), + ) + .expect("inserting an operation should not fail"); + mapping.insert(ref_state.nodes.len() - 1, operation); + + let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted operation"); + if c.len() == 2 { + assert_matches!( + open_state.0.inner[operation.0], + NodeInternal::BinaryOperation(position, _, _), + "the operation node should point to the position of the equivalent string in the intern list", + ); + } else { + assert_matches!( + open_state.0.inner[operation.0], + NodeInternal::Operation(position, _), + "the operation node should point to the position of the equivalent string in the intern list", + ); + } + //TODO: validate join nodes, child node equality + } else { + unreachable!(); + } + } + ExpressionTransition::AcquireRoot => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let root = open_state.root_node(); + assert_eq!( + ref_state.root.is_some(), + root.is_some(), + "both the reference state and system under test must match" + ); + + if let Some(root) = root + && let Some(ref_root) = ref_state.root + { + mapping.insert(ref_root, root); + } + } else { + unreachable!(); + } + } + ExpressionTransition::AcquireChildren(i) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + mapping.extend( + ref_state.children[&i] + .iter() + .copied() + .zip(open_state.children_of(mapping[&i])), + ); + } else { + unreachable!(); + } + } + } + state + } + + fn check_invariants( + state: &Self::SystemUnderTest, + ref_state: &<Self::Reference as ReferenceStateMachine>::State, + ) { + let expr = match state { + ExpressionWrapper::Open(_, expr) => &expr.0, + ExpressionWrapper::Closed(expr) => expr, + }; + + let mut deduplicated_interns = expr.interns.clone(); + deduplicated_interns.sort(); + deduplicated_interns.dedup(); + assert!( + deduplicated_interns.len() == expr.interns.len(), + "there must be no duplicated string interns" + ); + + for node in &expr.inner { + match node { + NodeInternal::Operation(i, Some(n)) => { + assert!( + i.0 < expr.interns.len(), + "intern index must be valid" + ); + assert!( + n.0 < expr.inner.len(), + "node index must be valid" + ); + } + NodeInternal::BinaryOperation(i, n1, n2) => { + assert!( + i.0 < expr.interns.len(), + "intern index must be valid" + ); + assert!( + n1.0 < expr.inner.len(), + "node index must be valid" + ); + assert!( + n2.0 < expr.inner.len(), + "node index must be valid" + ); + assert!( + !matches!( + expr.inner[n1.0], + NodeInternal::Join(_, _) + ), + "binary operations must not reference joins" + ); + assert!( + !matches!( + expr.inner[n2.0], + NodeInternal::Join(_, _) + ), + "binary operations must not reference joins" + ); + } + NodeInternal::Join(n1, n2) => { + assert!( + n1.0 < expr.inner.len(), + "node index must be valid" + ); + assert!( + !matches!( + expr.inner[n1.0], + NodeInternal::Join(_, _) + ), + "operations must not have joins as children" + ); + assert!( + n2.0 < expr.inner.len(), + "node index must be valid" + ); + } + NodeInternal::Variable(i) => { + assert!( + i.0 < expr.interns.len(), + "intern index must be valid" + ); + } + _ => (), + } + } + + //TODO: the graph must never contain a cycle. + //TODO: validate the layout of the entire expression against the + // reference. + } + } + + prop_state_machine! { + #![proptest_config(ProptestConfig { + // allow more rejects to make shrinking results better + max_global_rejects: 1_000_000, + + // disable failure persistence so miri works + #[cfg(miri)] + failure_persistence: None, + ..ProptestConfig::default() + })] + + #[test] + fn matches_state_machine(sequential 1..100 => ExpressionWrapper); + } +} |