diff options
-rw-r--r-- | crates/core/src/expressions/mod.rs | 141 | ||||
-rw-r--r-- | crates/core/src/hive/mod.rs | 6 | ||||
-rw-r--r-- | crates/core/src/numerics/rational.rs | 16 |
3 files changed, 151 insertions, 12 deletions
diff --git a/crates/core/src/expressions/mod.rs b/crates/core/src/expressions/mod.rs index c5abebf..902ead1 100644 --- a/crates/core/src/expressions/mod.rs +++ b/crates/core/src/expressions/mod.rs @@ -4,6 +4,7 @@ use alloc::{ alloc::{Allocator, Global}, collections::TryReserveError, string::String, + vec, vec::Vec, }; use core::{hint, iter::TrustedLen}; @@ -150,6 +151,106 @@ impl Default for Expression<Global> { } } +#[cfg(feature = "core-fmt")] +impl<A> fmt::Display for Expression<A> +where + A: Allocator + Clone, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + /// State we are in when we are at a node. + #[derive(PartialEq, Eq)] + enum State { + /// First time encountering a node. + Enter, + + /// Second time encountering a node. + Visit, + + /// Third time encountering a node. + Close, + } + + if let Some(root_index) = self.root { + let mut stack = vec![(root_index, State::Enter)]; + + while let Some((node_index, state)) = stack.pop() { + //SAFETY: all internal indices are guaranteed to be valid + match *unsafe { self.inner.get_unchecked(node_index.0) } { + NodeInternal::BinaryOperation(_, _, _) + if state == State::Close => + { + f.write_str(")")?; + } + NodeInternal::BinaryOperation( + StringIndexInternal(s), + _, + _, + ) if state == State::Visit => { + //SAFETY: all internal indices are guaranteed to be + // valid + write!(f, " {} ", unsafe { + self.interns.get_unchecked(s) + })?; + } + NodeInternal::BinaryOperation(_, c1, c2) + if state == State::Enter => + { + f.write_str("(")?; + stack.push((node_index, State::Close)); + stack.push((c2, State::Enter)); + stack.push((node_index, State::Visit)); + stack.push((c1, State::Enter)); + } + NodeInternal::Operation(StringIndexInternal(s), None) => { + //SAFETY: all internal indices are guaranteed to be + // valid + write!(f, "{}()", unsafe { + self.interns.get_unchecked(s) + })?; + } + NodeInternal::Operation(_, _) + if state == State::Close => + { + f.write_str(")")?; + } + NodeInternal::Operation( + StringIndexInternal(s), + Some(c1), + ) => { + //SAFETY: all internal indices are guaranteed to be + // valid + write!(f, "{}(", unsafe { + self.interns.get_unchecked(s) + })?; + stack.push((node_index, State::Close)); + stack.push((c1, State::Enter)); + } + NodeInternal::Join(_, _) if state == State::Visit => { + f.write_str(", ")?; + } + NodeInternal::Join(c1, c2) => { + stack.push((c2, State::Enter)); + stack.push((node_index, State::Visit)); + stack.push((c1, State::Enter)); + } + NodeInternal::Scalar(ref v) => { + write!(f, "{}", v)?; + } + NodeInternal::Variable(StringIndexInternal(s)) => { + //SAFETY: all internal indices are guaranteed to be + // valid + write!(f, "{}", unsafe { + self.interns.get_unchecked(s) + })?; + } + _ => unreachable!(), + } + } + } + Ok(()) + } +} + /// Representation of a mathematical expression. /// /// This is a variant of [`Expression`] in a state where it is able to be @@ -393,6 +494,16 @@ where } } +#[cfg(feature = "core-fmt")] +impl<A> fmt::Display for OpenExpression<'_, A> +where + A: Allocator + Clone, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + /// Iterator over the children of a [`Node`]. pub struct Children<'a, 'id> { /// The node we are currently at. @@ -674,6 +785,19 @@ pub enum ScalarValue { SignedIntegerRational(Rational<i64>), } +#[cfg(feature = "core-fmt")] +impl fmt::Display for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UnsignedInteger(v) => write!(f, "{}", v), + Self::Integer(v) => write!(f, "{}", v), + Self::Float(v) => write!(f, "{}", v), + Self::UnsignedIntegerRational(v) => write!(f, "{}", v), + Self::SignedIntegerRational(v) => write!(f, "{}", v), + } + } +} + /// Representation of an error that occurred within [`Expression`]. #[non_exhaustive] #[derive(Eq, PartialEq)] @@ -697,12 +821,12 @@ impl From<TryReserveError> for Error { #[cfg(feature = "core-fmt")] impl fmt::Display for Error { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let reason = match self { Error::TryReserveError(_) => "unable to allocate memory", Error::InvalidChild => "invalid child node", }; - fmt.write_str(reason) + f.write_str(reason) } } @@ -763,7 +887,10 @@ mod test { &[sin, sum] ); - expr.close(); + expr.set_root_node(product); + let expr = expr.close(); + + println!("{}", expr); } /// Reference state machine for [`Expression`]s. @@ -874,7 +1001,6 @@ mod test { state: &Self::State, ) -> BoxedStrategy<Self::Transition> { if state.is_open && state.valid_indices.is_empty() { - let n_valid = state.valid_indices.len(); prop_oneof![ 1 => Just(ExpressionTransition::Close), 1 => Just(ExpressionTransition::TakeRoot), @@ -1337,7 +1463,10 @@ mod test { prop_state_machine! { #![proptest_config(ProptestConfig { // allow more rejects to make shrinking results better - max_global_rejects: 1_000_000, + max_global_rejects: u32::MAX, + + // allow more shrinking iterations to improve the results + max_shrink_iters: u32::MAX, // disable failure persistence so miri works #[cfg(miri)] @@ -1346,6 +1475,6 @@ mod test { })] #[test] - fn matches_state_machine(sequential 1..100 => ExpressionWrapper); + fn matches_state_machine(sequential 1..1_000 => ExpressionWrapper); } } diff --git a/crates/core/src/hive/mod.rs b/crates/core/src/hive/mod.rs index ad14725..3f83099 100644 --- a/crates/core/src/hive/mod.rs +++ b/crates/core/src/hive/mod.rs @@ -297,8 +297,8 @@ pub enum Error { #[cfg(feature = "core-fmt")] impl fmt::Display for Error { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - fmt.write_str("memory allocation failed")?; + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("memory allocation failed")?; let reason = match self { Error::CapacityOverflow => { " because the computed capacity exceeded the hive's maximum" @@ -307,7 +307,7 @@ impl fmt::Display for Error { " because the memory allocator returned an error" } }; - fmt.write_str(reason) + f.write_str(reason) } } diff --git a/crates/core/src/numerics/rational.rs b/crates/core/src/numerics/rational.rs index e814dbd..173e6f6 100644 --- a/crates/core/src/numerics/rational.rs +++ b/crates/core/src/numerics/rational.rs @@ -106,6 +106,16 @@ where } } +#[cfg(feature = "core-fmt")] +impl<T> fmt::Display for Rational<T> +where + T: Integer + fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}/{}", self.numerator, self.denominator) + } +} + /// Representation of an error that occurred within [`Rational`]. #[non_exhaustive] #[derive(Eq, PartialEq)] @@ -117,12 +127,12 @@ pub enum Error { #[cfg(feature = "core-fmt")] impl fmt::Display for Error { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - fmt.write_str("rational number construction failed")?; + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("rational number construction failed")?; let reason = match self { Error::ZeroDenominator => " because the denominator was zero", }; - fmt.write_str(reason) + f.write_str(reason) } } |