From 83751efd734999fc11316a66317250ca53e76726 Mon Sep 17 00:00:00 2001 From: superwhiskers Date: Wed, 27 Aug 2025 14:41:19 -0500 Subject: initial expression implementation Change-Id: I6a6a69640c133bce112891bba09033b08e7c0dec --- crates/core/src/expressions/mod.rs | 1351 ++++++++++++++++++++++++++++++++++ crates/core/src/hive/group.rs | 24 +- crates/core/src/hive/mod.rs | 82 ++- crates/core/src/hive/skipfield.rs | 17 +- crates/core/src/lib.rs | 7 +- crates/core/src/numerics/mod.rs | 102 +++ crates/core/src/numerics/rational.rs | 147 ++++ 7 files changed, 1720 insertions(+), 10 deletions(-) create mode 100644 crates/core/src/expressions/mod.rs create mode 100644 crates/core/src/numerics/mod.rs create mode 100644 crates/core/src/numerics/rational.rs (limited to 'crates/core/src') diff --git a/crates/core/src/expressions/mod.rs b/crates/core/src/expressions/mod.rs new file mode 100644 index 0000000..c5abebf --- /dev/null +++ b/crates/core/src/expressions/mod.rs @@ -0,0 +1,1351 @@ +//! Common representation of mathematical expressions. + +use alloc::{ + alloc::{Allocator, Global}, + collections::TryReserveError, + string::String, + vec::Vec, +}; +use core::{hint, iter::TrustedLen}; +use generativity::{Guard, Id}; + +use crate::numerics::rational::Rational; + +#[cfg(feature = "core-fmt")] +use core::fmt; + +#[cfg(feature = "core-error")] +use core::error; + +//TODO: consider topological sorting of the expression's internals as this +// may be better for performance. +//TODO: consider garbage collection of unreachable nodes (and their +// children). +//TODO: if we add garbage collection, support marking a node for deletion +// when garbage collection is performed. +//TODO: consider renaming `OpenExpression` to something else. +//TODO: add more thorough documentation, along the lines of the standard +// library's `Vec`, given how this is a very fundamental type. + +/// Representation of a mathematical expression. +/// +/// Components of the expression are maintained in a vector where they are +/// linked to each other by their indices. Indices always point to another +/// existing node. +/// +/// # Guarantees +/// +/// - The expression contains no cycles. +/// - All references to other nodes are valid. +/// - All references to interned strings are valid. +/// - No strings are interned more than once. +#[derive(Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub struct Expression +where + A: Allocator + Clone, +{ + /// List of nodes linked to each other by their indices. + pub(crate) inner: Vec, + + /// List of interned strings. + pub(crate) interns: Vec, + + /// Root node of the expression. + pub(crate) root: Option, +} + +impl Expression { + /// Constructs an empty expression. + pub fn new() -> Self { + Self::new_in(Global) + } +} + +impl Expression +where + A: Allocator + Clone, +{ + /// Constructs an empty expression using the provided [`Allocator`]. + pub fn new_in(alloc: A) -> Self { + Self { + inner: Vec::new_in(alloc.clone()), + interns: Vec::new_in(alloc), + root: None, + } + } + + /// Clears the expression, removing all nodes and associated information. + pub fn clear(&mut self) { + self.inner.clear(); + self.interns.clear(); + self.root = None; + } + + /// Opens up the expression for additive operations and indexing. + /// + /// While in this state, no destructive operations may be done, such as + /// garbage collection and sorting of the node list. + pub fn open<'id>(self, guard: Guard<'id>) -> OpenExpression<'id, A> { + OpenExpression(self, guard.into()) + } + + /// Inserts a [`NodeInternal`] into the node list and returns the index. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// node. + fn insert_node( + &mut self, + node: NodeInternal, + ) -> Result { + self.inner.try_reserve(1)?; + //SAFETY: we just reserved space for this element + unsafe { self.inner.push_within_capacity(node).unwrap_unchecked() } + Ok(NodeIndexInternal(self.inner.len() - 1)) + } + + /// Inserts a [`&str`] into the intern list if it doesn't already exist + /// and returns the index. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// string. + fn insert_string( + &mut self, + string: impl AsRef, + ) -> Result { + if let Some(i) = + self.interns.iter().position(|s| s == string.as_ref()) + { + return Ok(StringIndexInternal(i)); + } + + self.interns.try_reserve(1)?; + //SAFETY: we just reserved space for this element + unsafe { + self.interns + .push_within_capacity(string.as_ref().into()) + .unwrap_unchecked() + } + Ok(StringIndexInternal(self.interns.len() - 1)) + } + + /// Returns the number of nodes within this expression. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Indicates if the expression has no nodes. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +impl Default for Expression { + fn default() -> Self { + Self::new() + } +} + +/// Representation of a mathematical expression. +/// +/// This is a variant of [`Expression`] in a state where it is able to be +/// written to and indexed. +#[derive(Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub struct OpenExpression<'id, A>(Expression, Id<'id>) +where + A: Allocator + Clone; + +impl<'id, A> OpenExpression<'id, A> +where + A: Allocator + Clone, +{ + /// Adds a new scalar value to the expression. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// scalar value. + pub fn insert_scalar( + &mut self, + value: impl Into, + ) -> Result, Error> { + Ok(( + self.0.insert_node(NodeInternal::Scalar(value.into()))?, + self.1, + ) + .into()) + } + + /// Adds a new variable to the expression. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// variable. + pub fn insert_variable( + &mut self, + name: impl AsRef, + ) -> Result, Error> { + let name = self.0.insert_string(name)?; + Ok( + (self.0.insert_node(NodeInternal::Variable(name))?, self.1) + .into(), + ) + } + + /// Adds a new operation to the expression. + /// + /// # Errors + /// + /// Returns an error if storage was unable to be reserved for the new + /// operation or if one of the children of the node would have been a + /// [`Node::Join`]. + pub fn insert_operation( + &mut self, + operator: impl AsRef, + arguments: Arguments, + ) -> Result, Error> + where + Arguments: IntoIterator>, + Arguments::IntoIter: TrustedLen, + { + let operator = self.0.insert_string(operator)?; + let mut arguments = arguments.into_iter(); + let n_arguments = arguments.size_hint().0; + if n_arguments == 2 { + //SAFETY: the iterator is required to implement `TrustedLen` + let first = unsafe { arguments.next().unwrap_unchecked() }; + + //SAFETY: the iterator is required to implement `TrustedLen` + let second = unsafe { arguments.next().unwrap_unchecked() }; + + if matches!( + //SAFETY: lifetime branding ensures this index is valid + unsafe { self.0.inner.get_unchecked(first.0) }, + NodeInternal::Join(_, _) + ) || matches!( + //SAFETY: lifetime branding ensures this index is valid + unsafe { self.0.inner.get_unchecked(second.0) }, + NodeInternal::Join(_, _) + ) { + return Err(Error::InvalidChild); + } + + Ok(( + self.0.insert_node(NodeInternal::BinaryOperation( + operator, + first.into(), + second.into(), + ))?, + self.1, + ) + .into()) + } else { + let number_of_joins = n_arguments.saturating_sub(1); + self.0.inner.try_reserve(number_of_joins + 1)?; + + //SAFETY: the iterator is required to implement `TrustedLen` so + // we will have reserved enough space for these in addition + // to guaranteeing the noted elements exist + // + // additionally, because of our lifetime branding, all + // indices referenced by lifetime-branded node indices are + // valid + unsafe { + self.0 + .inner + .push_within_capacity(match n_arguments { + 0 => NodeInternal::Operation(operator, None), + 1 => { + let argument = + arguments.next().unwrap_unchecked(); + if matches!( + self.0.inner.get_unchecked(argument.0), + NodeInternal::Join(_, _) + ) { + return Err(Error::InvalidChild); + } + + NodeInternal::Operation( + operator, + Some(argument.into()), + ) + } + //NOTE: we know that we will be adding at least one + // more slot to hold an `NodeInternal::Join` + _ => NodeInternal::Operation( + operator, + Some(NodeIndexInternal(self.0.inner.len() + 1)), + ), + }) + .unwrap_unchecked() + } + let operation_index = NodeIndexInternal(self.0.inner.len() - 1); + + //NOTE: we are exclusively inserting `NodeInternal::Join`s now + for (i, argument) in arguments.enumerate() { + if matches!( + //SAFETY: lifetime branding ensures this index is valid + unsafe { self.0.inner.get_unchecked(argument.0) }, + NodeInternal::Join(_, _) + ) { + //NOTE: leaving the inner vector in this state is not + // invalid as nothing will be able to reference any + // of the slots allocated already + // + // when it is added, garbage collection will clean up + // after this if it happens + return Err(Error::InvalidChild); + } + + if i != n_arguments - 1 { + //SAFETY: we know that we have reserved enough space for + // each join by this point + unsafe { + self.0 + .inner + .push_within_capacity(NodeInternal::Join( + argument.into(), + NodeIndexInternal(self.0.inner.len() + 1), + )) + .unwrap_unchecked() + } + } else { + let last_index = self.0.inner.len() - 1; + + if let &mut NodeInternal::Join(_, ref mut second) = + //SAFETY: we know that if we are at this point, + // the current last element exists + unsafe { + self.0.inner.get_unchecked_mut(last_index) + } + { + *second = argument.into(); + } else { + //SAFETY: we know by this point that this last + // element must be `NodeInternal::Join` + unsafe { hint::unreachable_unchecked() } + } + } + } + + Ok((operation_index, self.1).into()) + } + } + + /// Sets the root node of the expression. + pub fn set_root_node(&mut self, node: NodeIndex<'id>) { + self.0.root = Some(node.into()); + } + + /// Closes the expression to additive operations and being indexed. + pub fn close(self) -> Expression { + self.0 + } + + /// Returns the number of nodes within this expression. + pub fn len(&self) -> usize { + self.0.inner.len() + } + + /// Indicates if the expression has no nodes. + pub fn is_empty(&self) -> bool { + self.0.inner.is_empty() + } + + /// Gets the root node of the expression, if it exists. + pub fn root_node(&self) -> Option> { + Some((self.0.root?, self.1).into()) + } + + /// Takes the expression's root node. + pub fn take_root_node(&mut self) -> Option> { + Some((self.0.root.take()?, self.1).into()) + } + + /// Retrieves the indicated node from the expression. + pub fn node(&self, index: NodeIndex<'id>) -> Node<'id> { + //SAFETY: by the lifetime, this is guaranteed to be a valid index + ( + unsafe { self.0.inner.get_unchecked(index.0) }.clone(), + self.1, + ) + .into() + } + + /// Retrieves the indicated string from the expression. + pub fn string(&self, index: StringIndex<'id>) -> &str { + //SAFETY: by the lifetime, this is guaranteed to be a valid index + unsafe { self.0.interns.get_unchecked(index.0) } + } + + /// Returns an iterator over the children of the node. + pub fn children_of<'a>( + &'a self, + index: NodeIndex<'id>, + ) -> Children<'a, 'id> { + Children::new(&self.0.inner, index) + } +} + +/// Iterator over the children of a [`Node`]. +pub struct Children<'a, 'id> { + /// The node we are currently at. + /// + /// If the node given to the constructor was a [`Node::Scalar`], + /// [`Node::Variable`], or a [`Node::Operation`] with `None` as the + /// second parameter, this must be `None`, and hence this may never be + /// referencing a node of those kinds. + current: Option, + + /// Whether or not we are at the "second" part of the node. + /// + /// This is only relevant for [`Node::BinaryOperation`] as it is not + /// clear which of the children are being referenced. + second: bool, + + /// Reference to the internal storage of the [`Expression`]. + inner: &'a [NodeInternal], + + /// Lifetime tying it to the open state of the [`Expression`]. + id: Id<'id>, +} + +impl<'a, 'id> Children<'a, 'id> { + /// Creates a new iterator over the children of a [`Node`]. + fn new(inner: &'a [NodeInternal], index: NodeIndex<'id>) -> Self { + let id = index.1; + if matches!( + //SAFETY: by the lifetime, this is guaranteed to be a valid + // index + unsafe { inner.get_unchecked(index.0) }, + NodeInternal::Scalar(_) + | NodeInternal::Variable(_) + | NodeInternal::Operation(_, None) + ) { + Self { + current: None, + second: false, + inner, + id, + } + } else { + Self { + current: Some(index.into()), + second: false, + inner, + id, + } + } + } +} + +impl<'a, 'id> Iterator for Children<'a, 'id> { + type Item = NodeIndex<'id>; + + fn next(&mut self) -> Option { + if let Some(current) = self.current { + //SAFETY: by the lifetime, this is guaranteed to be a valid index + let current = unsafe { self.inner.get_unchecked(current.0) }; + Some( + ( + match *current { + NodeInternal::Operation(_, Some(index)) => { + //SAFETY: all internal indices are guaranteed to + // be valid + let next = + unsafe { self.inner.get_unchecked(index.0) }; + + if let NodeInternal::Join(child_index, _) = *next + { + self.current = Some(index); + child_index + } else { + self.current = None; + index + } + } + NodeInternal::Join(_, index) => { + //SAFETY: all internal indices are guaranteed to + // be valid + let next = + unsafe { self.inner.get_unchecked(index.0) }; + + if let NodeInternal::Join(child_index, _) = *next + { + self.current = Some(index); + child_index + } else { + self.current = None; + index + } + } + NodeInternal::BinaryOperation(_, index, _) + if !self.second => + { + self.second = true; + index + } + NodeInternal::BinaryOperation(_, _, index) => { + self.current = None; + index + } + _ => unreachable!(), + }, + self.id, + ) + .into(), + ) + } else { + None + } + } +} + +/// Index into the node vector of [`Expression`]. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +//NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 +#[allow(repr_transparent_external_private_fields)] +pub struct NodeIndex<'id>(usize, Id<'id>); + +impl<'id> From<(NodeIndexInternal, Id<'id>)> for NodeIndex<'id> { + fn from((index, id): (NodeIndexInternal, Id<'id>)) -> Self { + Self(index.0, id) + } +} + +/// Index into the node vector of [`Expression`]. +/// +/// This is only for internal use to avoid the presence of explicit lifetimes +/// on [`Expression`] members. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub(crate) struct NodeIndexInternal(usize); + +impl From> for NodeIndexInternal { + fn from(NodeIndex(index, _): NodeIndex<'_>) -> Self { + Self(index) + } +} + +/// Index into the string intern vector of [`Expression`]. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +//NOTE: false positive, see https://github.com/CAD97/generativity/issues/13 +#[allow(repr_transparent_external_private_fields)] +pub struct StringIndex<'id>(usize, Id<'id>); + +impl<'id> From<(StringIndexInternal, Id<'id>)> for StringIndex<'id> { + fn from((index, id): (StringIndexInternal, Id<'id>)) -> Self { + Self(index.0, id) + } +} + +/// Index into the string intern vector of [`Expression`]. +/// +/// This is only for internal use to avoid the presence of explicit lifetimes +/// on [`Expression`] members. +#[repr(transparent)] +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub(crate) struct StringIndexInternal(usize); + +impl From> for StringIndexInternal { + fn from(StringIndex(index, _): StringIndex<'_>) -> Self { + Self(index) + } +} + +/// Internal representation of an expression node. +#[non_exhaustive] +#[derive(PartialEq, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub(crate) enum NodeInternal { + /// Operation over other nodes. + /// + /// The first component indicates the kind of operation and the second is + /// the first argument. Several arguments may be represented through the + /// use of `Join`. + Operation(StringIndexInternal, Option), + + /// Binary operation. + /// + /// This is special cased due to how common they are. `Join`s are not + /// permitted to be referenced by the indices here. + BinaryOperation( + StringIndexInternal, + NodeIndexInternal, + NodeIndexInternal, + ), + + /// Joins one `NodeInternal` with another in a sequence, with the first + /// coming before the second. + /// + /// The first `NodeIndexInternal` must not be a `Join`. + Join(NodeIndexInternal, NodeIndexInternal), + + /// Scalar value. + Scalar(ScalarValue), + + /// Variable. + Variable(StringIndexInternal), +} + +/// Node in an expression. +#[non_exhaustive] +#[derive(PartialEq, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub enum Node<'id> { + /// Operation over other nodes. + /// + /// The first component indicates the kind of operation and the second is + /// the first argument. Several arguments may be represented through the + /// use of `Join`. + Operation(StringIndex<'id>, Option>), + + /// Binary operation. + /// + /// This is special cased due to how common they are. `Join`s are not + /// permitted to be referenced by the indices here. + BinaryOperation(StringIndex<'id>, NodeIndex<'id>, NodeIndex<'id>), + + /// Joins one `NodeInternal` with another in a sequence, with the first + /// coming before the second. + /// + /// The first `NodeIndexInternal` must not be a `Join`. + Join(NodeIndex<'id>, NodeIndex<'id>), + + /// Scalar value. + Scalar(ScalarValue), + + /// Variable. + Variable(StringIndex<'id>), +} + +impl<'id> From<(NodeInternal, Id<'id>)> for Node<'id> { + fn from((node, id): (NodeInternal, Id<'id>)) -> Self { + match node { + NodeInternal::Operation(s, n) => { + Node::Operation((s, id).into(), n.map(|n| (n, id).into())) + } + NodeInternal::BinaryOperation(s, n1, n2) => { + Node::BinaryOperation( + (s, id).into(), + (n1, id).into(), + (n2, id).into(), + ) + } + NodeInternal::Join(n1, n2) => { + Node::Join((n1, id).into(), (n2, id).into()) + } + NodeInternal::Scalar(s) => Node::Scalar(s), + NodeInternal::Variable(s) => Node::Variable((s, id).into()), + } + } +} + +/// Scalar value. +#[non_exhaustive] +#[derive(PartialEq, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub enum ScalarValue { + /// Unsigned integer. + UnsignedInteger(u64), + + /// Signed integer. + Integer(i64), + + /// Floating point value. + Float(f64), + + /// Rational number with unsigned integer components. + UnsignedIntegerRational(Rational), + + /// Rational number with signed integer components. + SignedIntegerRational(Rational), +} + +/// Representation of an error that occurred within [`Expression`]. +#[non_exhaustive] +#[derive(Eq, PartialEq)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub enum Error { + /// Memory reservation error. + TryReserveError(TryReserveError), + + /// Invalid child node. + /// + /// This occurs when you attempt to pass an index that points to a + /// [`Node::Join`] as a child of an operation. + InvalidChild, +} + +impl From for Error { + fn from(e: TryReserveError) -> Self { + Self::TryReserveError(e) + } +} + +#[cfg(feature = "core-fmt")] +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let reason = match self { + Error::TryReserveError(_) => "unable to allocate memory", + Error::InvalidChild => "invalid child node", + }; + fmt.write_str(reason) + } +} + +#[cfg(feature = "core-error")] +impl error::Error for Error {} + +#[cfg(all(test, feature = "core-error"))] +mod test { + #![allow(clippy::expect_used)] + + use super::*; + + use generativity::make_guard; + use proptest::prelude::*; + use proptest_state_machine::{ + ReferenceStateMachine, StateMachineTest, prop_state_machine, + }; + use std::{assert_matches::assert_matches, collections::HashMap}; + + #[test] + fn simple() { + make_guard!(g); + let mut expr = Expression::new().open(g); + + let a = expr + .insert_variable("apples") + .expect("unable to insert a variable"); + assert_eq!(expr.children_of(a).collect::>().as_slice(), &[]); + + let b = expr + .insert_variable("bananas") + .expect("unable to insert a variable"); + let c = expr + .insert_variable("oranges") + .expect("unable to insert a variable"); + + let sin = expr + .insert_operation("sin", [a]) + .expect("unable to insert an operation"); + assert_eq!( + expr.children_of(sin).collect::>().as_slice(), + &[a] + ); + + let sum = expr + .insert_operation("sum", [a, b, c]) + .expect("unable to insert an operation"); + assert_eq!( + expr.children_of(sum).collect::>().as_slice(), + &[a, b, c] + ); + + let product = expr + .insert_operation("*", [sin, sum]) + .expect("unable to insert an operation"); + assert_eq!( + expr.children_of(product).collect::>().as_slice(), + &[sin, sum] + ); + + expr.close(); + } + + /// Reference state machine for [`Expression`]s. + /// + /// Models [`Expression`]s as a root node index, a vector of nodes, a + /// mapping of parent node indices to a vector of child node indices, + /// a vector of valid indices, and whether or not it is currently open. + struct ExpressionReference; + + /// State used by the reference state machine. + #[derive(Clone, Debug, Default)] + struct ExpressionReferenceState { + /// Index of the root node in the node vector. + pub root: Option, + + /// Vector of nodes in the expression. + pub nodes: Vec, + + /// Mapping of parent node indices to child node indices. + pub children: HashMap>, + + /// The current list of valid indices. + pub valid_indices: Vec, + + /// Whether or not the expression is open. + /// + /// If it is open, then additive operations and indexing is allowed. + /// Otherwise, operations that may invalidate indices such as node + /// removal, sorting, etc are allowed. + pub is_open: bool, + } + + /// Simplified [`NodeInternal`] representation for the [`HashMap`] backed + /// graph. + #[derive(Clone, Debug)] + enum NodeRep { + /// Operation over other nodes. + Operation(String), + + /// Scalar value. + Scalar(ScalarValue), + + /// Variable. + Variable(String), + } + + /// Possible transitions for the state machine. + #[derive(Clone, Debug)] + enum ExpressionTransition { + /// Open the expression. + Open, + + /// Add a new operation to the expression. + NewOperation(String, Vec), + + /// Add a new scalar to the expression. + NewScalar(ScalarValue), + + /// Add a new variable to the expression. + NewVariable(String), + + /// Set the root node of the expression. + SetRoot(usize), + + /// Remove the root of the expression. + TakeRoot, + + /// Acquire a valid index to the root of expression. + AcquireRoot, + + /// Acquire valid indices to the children of the given node. + AcquireChildren(usize), + + /// Close the expression, invalidating all currently held indices. + Close, + + /// Remove all nodes from the expression. + Clear, + } + + impl ExpressionTransition { + /// Indicates if this transition requires that the expression be open. + fn needs_open(&self) -> bool { + matches!( + self, + ExpressionTransition::NewOperation(_, _) + | ExpressionTransition::NewScalar(_) + | ExpressionTransition::NewVariable(_) + | ExpressionTransition::SetRoot(_) + | ExpressionTransition::TakeRoot + | ExpressionTransition::AcquireRoot + | ExpressionTransition::AcquireChildren(_) + | ExpressionTransition::Close + ) + } + } + + impl ReferenceStateMachine for ExpressionReference { + type State = ExpressionReferenceState; + type Transition = ExpressionTransition; + + fn init_state() -> BoxedStrategy { + //TODO: randomly initialize state + Just(Self::State::default()).boxed() + } + + fn transitions( + state: &Self::State, + ) -> BoxedStrategy { + if state.is_open && state.valid_indices.is_empty() { + let n_valid = state.valid_indices.len(); + prop_oneof![ + 1 => Just(ExpressionTransition::Close), + 1 => Just(ExpressionTransition::TakeRoot), + 5 => Just(ExpressionTransition::AcquireRoot), + //NOTE: what a scalar value actually is has zero + // bearing on anything so we can just yield a static + // value here without any problems. in the future we + // may have something like a common subexpression + // elimination feature, which would make varying it + // valuable, but for now, we have no reason to do + // anything other than adding a static scalar. this + // is different for variables because string + // interning exists, so varying those does actually + // have an effect + 7 => Just(ExpressionTransition::NewScalar( + ScalarValue::UnsignedInteger(1) + )), + // 702 possible identifiers is probably enough + 7 => "[a-z]{1,2}" + .prop_map(ExpressionTransition::NewVariable), + ] + .boxed() + } else if state.is_open { + let n_valid = state.valid_indices.len(); + prop_oneof![ + 1 => Just(ExpressionTransition::Close), + 1 => Just(ExpressionTransition::TakeRoot), + 5 => Just(ExpressionTransition::AcquireRoot), + //NOTE: what a scalar value actually is has zero + // bearing on anything so we can just yield a static + // value here without any problems. in the future we + // may have something like a common subexpression + // elimination feature, which would make varying it + // valuable, but for now, we have no reason to do + // anything other than adding a static scalar. this + // is different for variables because string + // interning exists, so varying those does actually + // have an effect + 7 => Just(ExpressionTransition::NewScalar( + ScalarValue::UnsignedInteger(1) + )), + // 702 possible identifiers is probably enough + 7 => "[a-z]{1,2}" + .prop_map(ExpressionTransition::NewVariable), + 9 => prop::sample::subsequence( + state.valid_indices.clone(), + 0..=n_valid + ) + //NOTE: 702 possible identifiers is + // probably enough + .prop_flat_map(|v| ("[A-Z]{1,2}", Just(v))) + .prop_map(|(s, v)| + ExpressionTransition::NewOperation(s, v) + ), + 1 => prop::sample::select(state.valid_indices.clone()) + .prop_map(ExpressionTransition::SetRoot), + 5 => prop::sample::select(state.valid_indices.clone()) + .prop_map(ExpressionTransition::AcquireChildren), + ] + .boxed() + } else { + prop_oneof![ + 1 => Just(ExpressionTransition::Clear), + 9 => Just(ExpressionTransition::Open), + ] + .boxed() + } + } + + fn preconditions( + state: &Self::State, + transition: &Self::Transition, + ) -> bool { + if state.is_open ^ transition.needs_open() { + return false; + } + + match transition { + ExpressionTransition::NewOperation(_, children) => { + for c in children { + if *c >= state.nodes.len() + || !state.valid_indices.contains(c) + { + return false; + } + } + } + ExpressionTransition::SetRoot(node) + | ExpressionTransition::AcquireChildren(node) => { + if *node >= state.nodes.len() + || !state.valid_indices.contains(node) + { + return false; + } + } + _ => (), + } + true + } + + fn apply( + mut state: Self::State, + transition: &Self::Transition, + ) -> Self::State { + match transition { + ExpressionTransition::Open => state.is_open = true, + ExpressionTransition::Clear => { + state.nodes.clear(); + state.children.clear(); + state.root = None; + } + ExpressionTransition::Close => { + state.valid_indices.clear(); + state.is_open = false; + } + ExpressionTransition::SetRoot(index) => { + state.root = Some(*index); + } + ExpressionTransition::TakeRoot => { + state.root = None; + } + ExpressionTransition::NewVariable(s) => { + state.nodes.push(NodeRep::Variable(s.clone())); + let i = state.nodes.len() - 1; + state.valid_indices.push(i); + state.children.insert(i, Vec::new()); + } + ExpressionTransition::NewScalar(v) => { + state.nodes.push(NodeRep::Scalar(v.clone())); + let i = state.nodes.len() - 1; + state.valid_indices.push(i); + state.children.insert(i, Vec::new()); + } + ExpressionTransition::AcquireRoot => { + if let Some(root) = state.root { + state.valid_indices.push(root); + } + } + ExpressionTransition::AcquireChildren(i) => { + state.valid_indices.extend(&state.children[i]); + } + ExpressionTransition::NewOperation(s, c) => { + state.nodes.push(NodeRep::Operation(s.clone())); + let i = state.nodes.len() - 1; + state.children.insert(i, c.clone()); + state.valid_indices.push(i); + } + } + state + } + } + + /// Wrapper around an [`Expression`] indicating if it is open or not. + enum ExpressionWrapper<'id, A = Global> + where + A: Allocator + Clone, + { + /// Closed expression.. + Closed(Expression), + + /// Open expression and a mapping from the state machine's indices to + /// the [`Expression`]'s. + Open(HashMap>, OpenExpression<'id, A>), + } + + impl<'id> StateMachineTest for ExpressionWrapper<'id> { + type SystemUnderTest = Self; + type Reference = ExpressionReference; + + fn init_test( + _ref_state: &::State, + ) -> Self::SystemUnderTest { + //TODO: randomly initialize state in the reference and replicate + // it over here. + ExpressionWrapper::Closed(Expression::new()) + } + + fn apply( + mut state: Self::SystemUnderTest, + ref_state: &::State, + transition: ExpressionTransition, + ) -> Self::SystemUnderTest { + match transition { + ExpressionTransition::Open => { + if let ExpressionWrapper::Closed(closed_state) = state { + //SAFETY: unfortunately, we have to do this, + // otherwise this test cannot work out due to + // lifetime issues + let guard = unsafe { Guard::new(Id::new()) }; + state = ExpressionWrapper::Open( + HashMap::new(), + closed_state.open(guard), + ); + } else { + unreachable!(); + } + } + ExpressionTransition::Clear => { + if let ExpressionWrapper::Closed(closed_state) = + &mut state + { + closed_state.clear(); + + //NOTE: post-conditions + assert!( + closed_state.inner.is_empty(), + "inner node vector should be empty after clearing" + ); + assert!( + closed_state.interns.is_empty(), + "intern vector should be empty after clearing" + ); + assert!( + closed_state.root.is_none(), + "root node should not exist after clearing" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::Close => { + if let ExpressionWrapper::Open(_, open_state) = state { + state = ExpressionWrapper::Closed(open_state.close()); + } else { + unreachable!(); + } + } + ExpressionTransition::SetRoot(index) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + open_state.set_root_node(mapping[&index]); + + //NOTE: post-conditions + assert_eq!( + open_state.0.root, + Some(NodeIndexInternal(mapping[&index].0)), + "root should be set to the given index after setting it" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::TakeRoot => { + if let ExpressionWrapper::Open(_, open_state) = &mut state + { + let root_before = open_state.0.root; + let taken_root = open_state + .take_root_node() + .map(|r| NodeIndexInternal(r.0)); + + //NOTE: post-conditions + assert!( + open_state.0.root.is_none(), + "root should be None after taking it" + ); + assert_eq!( + root_before, taken_root, + "root before taking and returned root value from taking should be equal" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::NewVariable(s) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let var = open_state + .insert_variable(&s) + .expect("inserting a variable should not fail"); + mapping.insert(ref_state.nodes.len() - 1, var); + + let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted variable"); + assert_matches!( + open_state.0.inner[var.0], + NodeInternal::Variable(position), + "the variable node should point to the position of the equivalent string in the intern list" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::NewScalar(v) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let scalar = open_state + .insert_scalar(v) + .expect("inserting a scalar should not fail"); + mapping.insert(ref_state.nodes.len() - 1, scalar); + assert_matches!( + &open_state.0.inner[scalar.0], + NodeInternal::Scalar(v), + "the scalar node should contain the same value as was inserted" + ); + } else { + unreachable!(); + } + } + ExpressionTransition::NewOperation(s, c) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let operation = open_state + .insert_operation( + &s, + c.iter().copied().map(|i| mapping[&i]), + ) + .expect("inserting an operation should not fail"); + mapping.insert(ref_state.nodes.len() - 1, operation); + + let position = open_state.0.interns.iter().position(|t| t == s.as_str()).expect("the interns list should contain the name of the inserted operation"); + if c.len() == 2 { + assert_matches!( + open_state.0.inner[operation.0], + NodeInternal::BinaryOperation(position, _, _), + "the operation node should point to the position of the equivalent string in the intern list", + ); + } else { + assert_matches!( + open_state.0.inner[operation.0], + NodeInternal::Operation(position, _), + "the operation node should point to the position of the equivalent string in the intern list", + ); + } + //TODO: validate join nodes, child node equality + } else { + unreachable!(); + } + } + ExpressionTransition::AcquireRoot => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + let root = open_state.root_node(); + assert_eq!( + ref_state.root.is_some(), + root.is_some(), + "both the reference state and system under test must match" + ); + + if let Some(root) = root + && let Some(ref_root) = ref_state.root + { + mapping.insert(ref_root, root); + } + } else { + unreachable!(); + } + } + ExpressionTransition::AcquireChildren(i) => { + if let ExpressionWrapper::Open(mapping, open_state) = + &mut state + { + mapping.extend( + ref_state.children[&i] + .iter() + .copied() + .zip(open_state.children_of(mapping[&i])), + ); + } else { + unreachable!(); + } + } + } + state + } + + fn check_invariants( + state: &Self::SystemUnderTest, + ref_state: &::State, + ) { + let expr = match state { + ExpressionWrapper::Open(_, expr) => &expr.0, + ExpressionWrapper::Closed(expr) => expr, + }; + + let mut deduplicated_interns = expr.interns.clone(); + deduplicated_interns.sort(); + deduplicated_interns.dedup(); + assert!( + deduplicated_interns.len() == expr.interns.len(), + "there must be no duplicated string interns" + ); + + for node in &expr.inner { + match node { + NodeInternal::Operation(i, Some(n)) => { + assert!( + i.0 < expr.interns.len(), + "intern index must be valid" + ); + assert!( + n.0 < expr.inner.len(), + "node index must be valid" + ); + } + NodeInternal::BinaryOperation(i, n1, n2) => { + assert!( + i.0 < expr.interns.len(), + "intern index must be valid" + ); + assert!( + n1.0 < expr.inner.len(), + "node index must be valid" + ); + assert!( + n2.0 < expr.inner.len(), + "node index must be valid" + ); + assert!( + !matches!( + expr.inner[n1.0], + NodeInternal::Join(_, _) + ), + "binary operations must not reference joins" + ); + assert!( + !matches!( + expr.inner[n2.0], + NodeInternal::Join(_, _) + ), + "binary operations must not reference joins" + ); + } + NodeInternal::Join(n1, n2) => { + assert!( + n1.0 < expr.inner.len(), + "node index must be valid" + ); + assert!( + !matches!( + expr.inner[n1.0], + NodeInternal::Join(_, _) + ), + "operations must not have joins as children" + ); + assert!( + n2.0 < expr.inner.len(), + "node index must be valid" + ); + } + NodeInternal::Variable(i) => { + assert!( + i.0 < expr.interns.len(), + "intern index must be valid" + ); + } + _ => (), + } + } + + //TODO: the graph must never contain a cycle. + //TODO: validate the layout of the entire expression against the + // reference. + } + } + + prop_state_machine! { + #![proptest_config(ProptestConfig { + // allow more rejects to make shrinking results better + max_global_rejects: 1_000_000, + + // disable failure persistence so miri works + #[cfg(miri)] + failure_persistence: None, + ..ProptestConfig::default() + })] + + #[test] + fn matches_state_machine(sequential 1..100 => ExpressionWrapper); + } +} diff --git a/crates/core/src/hive/group.rs b/crates/core/src/hive/group.rs index 32d070e..9217897 100644 --- a/crates/core/src/hive/group.rs +++ b/crates/core/src/hive/group.rs @@ -1,4 +1,4 @@ -//! An implementation of the individual memory blocks that make up a [`Hive`]. +//! An implementation of the individual memory blocks that make up a hive. use core::{ mem::{self, ManuallyDrop}, @@ -27,6 +27,7 @@ where /// implementation. #[repr(C, packed)] #[derive(Copy, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] pub struct FreeList where Sk: skipfield::SkipfieldType, @@ -39,6 +40,7 @@ where } /// A doubly-linked `Group` of `T` with a skipfield type of `Sk`. +#[cfg_attr(feature = "core-fmt", derive(Debug))] pub struct Group where Sk: skipfield::SkipfieldType, @@ -74,7 +76,7 @@ where /// Pointer to the previous [`Group`] with erased elements. pub previous_with_erasures: Option>>, - /// Number assigned to this group in the [`Hive`]. + /// Number assigned to this group. pub number: usize, } @@ -98,7 +100,23 @@ const fn compute_element_allocation_size() -> usize { } else { t_align } + } else if sk2_size > t_size { + sk2_size } else { - if sk2_size > t_size { sk2_size } else { t_size } + t_size + } +} + +#[cfg(all(test, feature = "core-error"))] +mod test { + use super::*; + + #[test] + fn validate_element_allocation_size() { + assert_eq!( + Group::::ELEMENT_ALLOCATION_SIZE, + 4, + "element allocation size with T = u32, Sk = u8 is 4" + ); } } diff --git a/crates/core/src/hive/mod.rs b/crates/core/src/hive/mod.rs index 17568e3..ad14725 100644 --- a/crates/core/src/hive/mod.rs +++ b/crates/core/src/hive/mod.rs @@ -12,6 +12,7 @@ // parameters into a separate struct to reduce the amount of code // generated, akin to what the standard library does with `RawVec` and // `RawVecInner` +//TODO: try_reserve_exact, reserve_exact use alloc::alloc::{Allocator, Global, Layout}; use core::{cmp, mem, ptr::NonNull}; @@ -26,6 +27,7 @@ pub mod group; pub mod skipfield; /// An implementation of a bucket array using a skipfield. +#[cfg_attr(feature = "core-fmt", derive(Debug))] pub struct Hive where Sk: skipfield::SkipfieldType, @@ -127,6 +129,14 @@ where cmp::min(Sk::from_usize(adaptive_size), max_capacity), ); + //NOTE: anything that calls `panic` indirectly or does anything that + // touches standard output requires core::fmt :skull: + #[cfg(feature = "core-fmt")] + debug_assert!( + max_capacity >= min_capacity, + "maximum capacity bound is greater than or equal to the minimum capacity bound" + ); + (min_capacity, max_capacity) } @@ -161,6 +171,9 @@ where /// exceeds `isize::MAX` bytes. #[cfg(feature = "core-fmt")] pub fn with_capacity_in(capacity: usize, alloc: A) -> Self { + //PANIC: this is acceptable as the panic is mentioned above and it is + // used to assert an invariant. + #[allow(clippy::expect_used)] Self::try_with_capacity_in(capacity, alloc) .expect("allocation should not fail") } @@ -181,6 +194,73 @@ where capacity: usize, alloc: A, ) -> Result { + let mut hive = Self::new_in(alloc); + hive.try_reserve(capacity)?; + Ok(hive) + } + + /// Reserves capacity for at least `additional` more elements to be + /// inserted in the given `Hive`. + /// + /// The collection may reserve more space to avoid future allocations. + /// After calling `reserve`, the capacity will be greater than or equal to + /// `self.len() + additional`. Does nothing if the capacity is already + /// sufficient. + /// + /// # Panics + /// + /// Panics if the allocator reports allocation failure or if the new + /// capacity exceeds `isize::MAX` bytes. + #[cfg(feature = "core-fmt")] + pub fn reserve(&mut self, additional: usize) { + todo!() + } + + /// Reserves capacity for at least `additional` more elements to be + /// inserted in the given `Hive`. + /// + /// The collection may reserve more space to avoid future allocations. + /// After calling `reserve`, the capacity will be greater than or equal to + /// `self.len() + additional`. Does nothing if the capacity is already + /// sufficient. + /// + /// # Errors + /// + /// Returns an error if the allocator reports allocation failure or if the + /// new capacity exceeds `isize::MAX` bytes. + pub fn try_reserve(&mut self, additional: usize) -> Result<(), Error> { + todo!() + } + + /// Checks if the container needs to grow to accommodate `additional` more + /// elements + #[inline] + fn needs_to_grow(&self, additional: usize) -> bool { + additional > self.capacity.wrapping_sub(self.size) + } + + /// Grow the `Hive` by the given amount, leaving room for more + /// elements than necessary. + /// + /// # Errors + /// + /// Returns an error if the allocator reports allocation failure or if the + /// new capacity exceeds `isize::MAX` bytes. + fn grow_amortized(&mut self, additional: usize) -> Result<(), Error> { + #[cfg(feature = "core-fmt")] + debug_assert!( + additional > 0, + "at least space enough for one element will be added" + ); + + if mem::size_of::() == 0 { + //NOTE: akin to raw_vec in alloc, if we get here with a zero + // sized type, since the capacity is definitionally full when + // it is holding one, the hive would necessarily + // be overfull + return Err(Error::CapacityOverflow); + } + todo!() } } @@ -202,7 +282,7 @@ pub enum Error { /// Allocation size exceeded `isize::MAX`. CapacityOverflow, - /// An unspecified allocation error occurred. + /// Unspecified allocation error. /// /// The layout used during allocation is provided for troubleshooting /// where that is possible. Ideally, future revisions of the diff --git a/crates/core/src/hive/skipfield.rs b/crates/core/src/hive/skipfield.rs index 0fb2ab4..16a3dcb 100644 --- a/crates/core/src/hive/skipfield.rs +++ b/crates/core/src/hive/skipfield.rs @@ -4,7 +4,10 @@ //! types may be used by implementing this trait on them; the trait is not //! sealed. -use core::ops::{Add, AddAssign, Sub, SubAssign}; +use core::{ + cmp, + ops::{Add, AddAssign, Sub, SubAssign}, +}; /// Trait describing integral types in a generic way suitable for use as the /// element type of a skipfield. @@ -21,9 +24,13 @@ pub trait SkipfieldType: const ONE: Self; /// Conversion method from `usize` using `as` or an equivalent + /// + /// Caps the value of the input by the maximum of `Self`. fn from_usize(u: usize) -> Self; /// Conversion method from `isize` using `as` or an equivalent + /// + /// Caps the value of the input by the maximum of `Self`. fn from_isize(i: isize) -> Self; } @@ -34,12 +41,12 @@ impl SkipfieldType for u16 { #[inline(always)] fn from_usize(u: usize) -> Self { - u as u16 + cmp::min(u, Self::MAXIMUM as usize) as u16 } #[inline(always)] fn from_isize(i: isize) -> Self { - i as u16 + cmp::min(i, Self::MAXIMUM as isize) as u16 } } @@ -50,11 +57,11 @@ impl SkipfieldType for u8 { #[inline(always)] fn from_usize(u: usize) -> Self { - u as u8 + cmp::min(u, Self::MAXIMUM as usize) as u8 } #[inline(always)] fn from_isize(i: isize) -> Self { - i as u8 + cmp::min(i, Self::MAXIMUM as isize) as u8 } } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 330ad58..54c04c5 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -1,4 +1,4 @@ -#![no_std] +#![cfg_attr(not(test), no_std)] #![warn( clippy::cargo_common_metadata, clippy::dbg_macro, @@ -54,9 +54,14 @@ clippy::wildcard_imports )] #![feature(allocator_api)] +#![feature(vec_push_within_capacity)] +#![feature(trusted_len)] +#![feature(assert_matches)] //! A highly portable computer algebra system library implemented in Rust extern crate alloc; +pub mod expressions; pub mod hive; +pub mod numerics; diff --git a/crates/core/src/numerics/mod.rs b/crates/core/src/numerics/mod.rs new file mode 100644 index 0000000..9e3d782 --- /dev/null +++ b/crates/core/src/numerics/mod.rs @@ -0,0 +1,102 @@ +//! Number traits and utilities. + +use core::{ + mem, + ops::{ + Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, + SubAssign, + }, +}; + +pub mod rational; + +//NOTE: these should probably be broken up a bunch; they're only implemented +// this way because i didn't really need much else to implement +// rationals. ideally we should break them up on the lines of rings, +// groups, fields, unique factorization domains, principal ideal domains, +// ordered rings, etc. this would make further extensions to +// functionality in azimuth easier. + +/// Integer. +pub trait Integer: + Eq + + Ord + + Mul + + Div + + Add + + Sub + + Rem + + MulAssign + + DivAssign + + AddAssign + + SubAssign + + RemAssign + + Sized + + Copy + + Clone +{ + //NOTE: `Neg` is neglected here as the primary purpose of this trait is + // to abstract over integers for the `Rational` type. as stated + // above, i'd like to adjust the way all of this works anyway for + // future extensions, but for now, we don't really need it. + + //NOTE: `Copy` is required for now as bignum support is not a requirement + // yet and this makes implementation simpler. + + /// Additive identity. + const ZERO: Self; + + /// Multiplicative identity. + const ONE: Self; + + /// Maximum attainable value. + const MAX: Self; + + /// Minimum attainable value. + const MIN: Self; + + /// Calculates the greatest common divisor of this number and another + fn gcd(&self, rhs: &Self) -> Self; +} + +/// Helper macro for the implementation of `Integer` +macro_rules! integer_impl { + ($type:ty, $zero:expr, $one:expr, $max:expr, $min:expr) => { + impl Integer for $type { + const ZERO: Self = $zero; + const ONE: Self = $one; + const MAX: Self = $max; + const MIN: Self = $min; + + #[inline] + fn gcd(&self, rhs: &Self) -> Self { + let mut m = *self; + let mut n = *rhs; + + if m < n { + mem::swap(&mut m, &mut n); + } + + while n != 0 { + let t = n; + n = m % n; + m = t; + } + + m + } + } + }; +} + +integer_impl!(i8, 0, 1, i8::MAX, i8::MIN); +integer_impl!(i16, 0, 1, i16::MAX, i16::MIN); +integer_impl!(i32, 0, 1, i32::MAX, i32::MIN); +integer_impl!(i64, 0, 1, i64::MAX, i64::MIN); +integer_impl!(i128, 0, 1, i128::MAX, i128::MIN); + +integer_impl!(u8, 0, 1, u8::MAX, u8::MIN); +integer_impl!(u16, 0, 1, u16::MAX, u16::MIN); +integer_impl!(u32, 0, 1, u32::MAX, u32::MIN); +integer_impl!(u64, 0, 1, u64::MAX, u64::MIN); +integer_impl!(u128, 0, 1, u128::MAX, u128::MIN); diff --git a/crates/core/src/numerics/rational.rs b/crates/core/src/numerics/rational.rs new file mode 100644 index 0000000..e814dbd --- /dev/null +++ b/crates/core/src/numerics/rational.rs @@ -0,0 +1,147 @@ +//! Rational numbers. + +//TODO: implement overloads for `Rational`. +//TODO: reduce rational numbers before construction is finished. + +use crate::numerics::Integer; + +#[cfg(feature = "core-fmt")] +use core::fmt; + +#[cfg(feature = "core-error")] +use core::error; + +/// Rational number. +/// +/// All operations +#[derive(Eq, PartialEq, Copy, Clone)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub struct Rational { + numerator: T, + denominator: T, +} + +impl Rational +where + T: Integer, +{ + /// Constructs a new `Rational`. + /// + /// # Panics + /// + /// Panics if the denominator is zero. + #[cfg(feature = "core-fmt")] + pub fn new(numerator: T, denominator: T) -> Self { + //PANIC: this is acceptable as the panic is mentioned above and it is + // used to assert an invariant. + #[allow(clippy::expect_used)] + Self::try_new(numerator, denominator) + .expect("denominator should not be zero") + } + + /// Constructs a new `Rational`. + /// + /// # Errors + /// + /// Returns an error if the denominator is zero. + pub fn try_new(numerator: T, denominator: T) -> Result { + if denominator == T::ZERO { + return Err(Error::ZeroDenominator); + } + let mut rational = Self { + numerator, + denominator, + }; + rational.reduce(); + Ok(rational) + } + + /// Converts the rational into its components. + #[inline] + pub fn into_parts(self) -> (T, T) { + (self.numerator, self.denominator) + } + + /// Returns both the numerator and denominator of the rational. + #[inline] + pub fn as_parts(&self) -> (&T, &T) { + (&self.numerator, &self.denominator) + } + + /// Returns the numerator of the rational. + #[inline] + pub fn numerator(&self) -> &T { + &self.numerator + } + + /// Returns the denominator of the rational. + #[inline] + pub fn denominator(&self) -> &T { + &self.denominator + } + + /// Reduces the rational so that the numerator and denominator have no + /// common factors and the denominator is greater than the zero element. + fn reduce(&mut self) { + if self.numerator == T::ZERO { + self.denominator = T::ONE; + return; + } + + if self.numerator == self.denominator { + self.numerator = T::ONE; + self.denominator = T::ONE; + return; + } + + let gcd = self.numerator.gcd(&self.denominator); + + self.numerator /= gcd; + self.denominator /= gcd; + + if self.denominator < T::ZERO { + self.numerator = T::ZERO - self.numerator; + self.denominator = T::ZERO - self.denominator; + } + } +} + +/// Representation of an error that occurred within [`Rational`]. +#[non_exhaustive] +#[derive(Eq, PartialEq)] +#[cfg_attr(feature = "core-fmt", derive(Debug))] +pub enum Error { + /// Denominator was equal to zero. + ZeroDenominator, +} + +#[cfg(feature = "core-fmt")] +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + fmt.write_str("rational number construction failed")?; + let reason = match self { + Error::ZeroDenominator => " because the denominator was zero", + }; + fmt.write_str(reason) + } +} + +#[cfg(feature = "core-error")] +impl error::Error for Error {} + +#[cfg(all(test, feature = "core-error"))] +mod test { + #![allow(clippy::expect_used)] + + use super::*; + + #[test] + fn rational_sanity() { + assert_eq!( + Rational::try_new(6, 3) + .expect("unable to construct a rational number") + .into_parts(), + (2, 1) + ); + } +} -- cgit 1.4.1-2-gfad0