about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorsuperwhiskers <[email protected]>2025-09-14 08:06:27 -0500
committersuperwhiskers <[email protected]>2025-09-15 10:55:10 -0500
commit50192cbe91da765d3c09cf8e12f328b721d3cb46 (patch)
tree345e69ef1141c26774677982fe5eaba875dbbbe0
parent83751efd734999fc11316a66317250ca53e76726 (diff)
downloadazimuth-canon.tar.gz
azimuth-canon.tar.bz2
azimuth-canon.zip
add a `Display` implementation to `Expression` HEAD canon
this change adds a string formatter for `Expression`s using the `Display`
trait. additionally, it standardizes the way `Display` implementations
are written and makes some minor adjustments to the parameters used for
the proptest-based test for `Expression`.

Change-Id: I6a6a6964cd5c04e95341a499dcd73297ca2f514a
-rw-r--r--crates/core/src/expressions/mod.rs141
-rw-r--r--crates/core/src/hive/mod.rs6
-rw-r--r--crates/core/src/numerics/rational.rs16
3 files changed, 151 insertions, 12 deletions
diff --git a/crates/core/src/expressions/mod.rs b/crates/core/src/expressions/mod.rs
index c5abebf..902ead1 100644
--- a/crates/core/src/expressions/mod.rs
+++ b/crates/core/src/expressions/mod.rs
@@ -4,6 +4,7 @@ use alloc::{
     alloc::{Allocator, Global},
     collections::TryReserveError,
     string::String,
+    vec,
     vec::Vec,
 };
 use core::{hint, iter::TrustedLen};
@@ -150,6 +151,106 @@ impl Default for Expression<Global> {
     }
 }
 
+#[cfg(feature = "core-fmt")]
+impl<A> fmt::Display for Expression<A>
+where
+    A: Allocator + Clone,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        /// State we are in when we are at a node.
+        #[derive(PartialEq, Eq)]
+        enum State {
+            /// First time encountering a node.
+            Enter,
+
+            /// Second time encountering a node.
+            Visit,
+
+            /// Third time encountering a node.
+            Close,
+        }
+
+        if let Some(root_index) = self.root {
+            let mut stack = vec![(root_index, State::Enter)];
+
+            while let Some((node_index, state)) = stack.pop() {
+                //SAFETY: all internal indices are guaranteed to be valid
+                match *unsafe { self.inner.get_unchecked(node_index.0) } {
+                    NodeInternal::BinaryOperation(_, _, _)
+                        if state == State::Close =>
+                    {
+                        f.write_str(")")?;
+                    }
+                    NodeInternal::BinaryOperation(
+                        StringIndexInternal(s),
+                        _,
+                        _,
+                    ) if state == State::Visit => {
+                        //SAFETY: all internal indices are guaranteed to be
+                        //        valid
+                        write!(f, " {} ", unsafe {
+                            self.interns.get_unchecked(s)
+                        })?;
+                    }
+                    NodeInternal::BinaryOperation(_, c1, c2)
+                        if state == State::Enter =>
+                    {
+                        f.write_str("(")?;
+                        stack.push((node_index, State::Close));
+                        stack.push((c2, State::Enter));
+                        stack.push((node_index, State::Visit));
+                        stack.push((c1, State::Enter));
+                    }
+                    NodeInternal::Operation(StringIndexInternal(s), None) => {
+                        //SAFETY: all internal indices are guaranteed to be
+                        //        valid
+                        write!(f, "{}()", unsafe {
+                            self.interns.get_unchecked(s)
+                        })?;
+                    }
+                    NodeInternal::Operation(_, _)
+                        if state == State::Close =>
+                    {
+                        f.write_str(")")?;
+                    }
+                    NodeInternal::Operation(
+                        StringIndexInternal(s),
+                        Some(c1),
+                    ) => {
+                        //SAFETY: all internal indices are guaranteed to be
+                        //        valid
+                        write!(f, "{}(", unsafe {
+                            self.interns.get_unchecked(s)
+                        })?;
+                        stack.push((node_index, State::Close));
+                        stack.push((c1, State::Enter));
+                    }
+                    NodeInternal::Join(_, _) if state == State::Visit => {
+                        f.write_str(", ")?;
+                    }
+                    NodeInternal::Join(c1, c2) => {
+                        stack.push((c2, State::Enter));
+                        stack.push((node_index, State::Visit));
+                        stack.push((c1, State::Enter));
+                    }
+                    NodeInternal::Scalar(ref v) => {
+                        write!(f, "{}", v)?;
+                    }
+                    NodeInternal::Variable(StringIndexInternal(s)) => {
+                        //SAFETY: all internal indices are guaranteed to be
+                        //        valid
+                        write!(f, "{}", unsafe {
+                            self.interns.get_unchecked(s)
+                        })?;
+                    }
+                    _ => unreachable!(),
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
 /// Representation of a mathematical expression.
 ///
 /// This is a variant of [`Expression`] in a state where it is able to be
@@ -393,6 +494,16 @@ where
     }
 }
 
+#[cfg(feature = "core-fmt")]
+impl<A> fmt::Display for OpenExpression<'_, A>
+where
+    A: Allocator + Clone,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
 /// Iterator over the children of a [`Node`].
 pub struct Children<'a, 'id> {
     /// The node we are currently at.
@@ -674,6 +785,19 @@ pub enum ScalarValue {
     SignedIntegerRational(Rational<i64>),
 }
 
+#[cfg(feature = "core-fmt")]
+impl fmt::Display for ScalarValue {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::UnsignedInteger(v) => write!(f, "{}", v),
+            Self::Integer(v) => write!(f, "{}", v),
+            Self::Float(v) => write!(f, "{}", v),
+            Self::UnsignedIntegerRational(v) => write!(f, "{}", v),
+            Self::SignedIntegerRational(v) => write!(f, "{}", v),
+        }
+    }
+}
+
 /// Representation of an error that occurred within [`Expression`].
 #[non_exhaustive]
 #[derive(Eq, PartialEq)]
@@ -697,12 +821,12 @@ impl From<TryReserveError> for Error {
 
 #[cfg(feature = "core-fmt")]
 impl fmt::Display for Error {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         let reason = match self {
             Error::TryReserveError(_) => "unable to allocate memory",
             Error::InvalidChild => "invalid child node",
         };
-        fmt.write_str(reason)
+        f.write_str(reason)
     }
 }
 
@@ -763,7 +887,10 @@ mod test {
             &[sin, sum]
         );
 
-        expr.close();
+        expr.set_root_node(product);
+        let expr = expr.close();
+
+        println!("{}", expr);
     }
 
     /// Reference state machine for [`Expression`]s.
@@ -874,7 +1001,6 @@ mod test {
             state: &Self::State,
         ) -> BoxedStrategy<Self::Transition> {
             if state.is_open && state.valid_indices.is_empty() {
-                let n_valid = state.valid_indices.len();
                 prop_oneof![
                     1 => Just(ExpressionTransition::Close),
                     1 => Just(ExpressionTransition::TakeRoot),
@@ -1337,7 +1463,10 @@ mod test {
     prop_state_machine! {
         #![proptest_config(ProptestConfig {
             // allow more rejects to make shrinking results better
-            max_global_rejects: 1_000_000,
+            max_global_rejects: u32::MAX,
+
+            // allow more shrinking iterations to improve the results
+            max_shrink_iters: u32::MAX,
 
             // disable failure persistence so miri works
             #[cfg(miri)]
@@ -1346,6 +1475,6 @@ mod test {
         })]
 
         #[test]
-        fn matches_state_machine(sequential 1..100 => ExpressionWrapper);
+        fn matches_state_machine(sequential 1..1_000 => ExpressionWrapper);
     }
 }
diff --git a/crates/core/src/hive/mod.rs b/crates/core/src/hive/mod.rs
index ad14725..3f83099 100644
--- a/crates/core/src/hive/mod.rs
+++ b/crates/core/src/hive/mod.rs
@@ -297,8 +297,8 @@ pub enum Error {
 
 #[cfg(feature = "core-fmt")]
 impl fmt::Display for Error {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
-        fmt.write_str("memory allocation failed")?;
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str("memory allocation failed")?;
         let reason = match self {
             Error::CapacityOverflow => {
                 " because the computed capacity exceeded the hive's maximum"
@@ -307,7 +307,7 @@ impl fmt::Display for Error {
                 " because the memory allocator returned an error"
             }
         };
-        fmt.write_str(reason)
+        f.write_str(reason)
     }
 }
 
diff --git a/crates/core/src/numerics/rational.rs b/crates/core/src/numerics/rational.rs
index e814dbd..173e6f6 100644
--- a/crates/core/src/numerics/rational.rs
+++ b/crates/core/src/numerics/rational.rs
@@ -106,6 +106,16 @@ where
     }
 }
 
+#[cfg(feature = "core-fmt")]
+impl<T> fmt::Display for Rational<T>
+where
+    T: Integer + fmt::Display,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}/{}", self.numerator, self.denominator)
+    }
+}
+
 /// Representation of an error that occurred within [`Rational`].
 #[non_exhaustive]
 #[derive(Eq, PartialEq)]
@@ -117,12 +127,12 @@ pub enum Error {
 
 #[cfg(feature = "core-fmt")]
 impl fmt::Display for Error {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
-        fmt.write_str("rational number construction failed")?;
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str("rational number construction failed")?;
         let reason = match self {
             Error::ZeroDenominator => " because the denominator was zero",
         };
-        fmt.write_str(reason)
+        f.write_str(reason)
     }
 }