rust/fpnum/src/lib.rs
branchtransitional_engine
changeset 16081 6633961698ad
parent 16079 624b74443b53
--- a/rust/fpnum/src/lib.rs	Thu Dec 19 12:43:38 2024 +0100
+++ b/rust/fpnum/src/lib.rs	Thu Dec 19 14:18:55 2024 +0100
@@ -1,4 +1,5 @@
 use std::{cmp, ops};
+use std::marker::PhantomData;
 use saturate::SaturatingInto;
 
 const POSITIVE_MASK: u64 = 0x0000_0000_0000_0000;
@@ -13,16 +14,20 @@
     }
 }
 
+struct FracBits<const N: u8>;
 #[derive(Clone, Debug, Copy)]
-pub struct FPNum {
+pub struct FixedPoint<const FRAC_BITS: u8> {
     sign_mask: u64,
     value: u64,
+    _marker: PhantomData<FracBits<FRAC_BITS>>,
 }
 
-impl FPNum {
+pub type FPNum = FixedPoint<20>;
+
+impl<const FRAC_BITS: u8> FixedPoint<FRAC_BITS> {
     #[inline]
     pub fn new(numerator: i32, denominator: u32) -> Self {
-        FPNum::from(numerator) / denominator
+        Self::from(numerator) / denominator
     }
 
     #[inline]
@@ -50,24 +55,26 @@
         Self {
             sign_mask: POSITIVE_MASK,
             value: self.value,
+            _marker: self._marker,
         }
     }
 
     #[inline]
-    pub fn round(&self) -> i32 {
-        ((self.value >> 32) as i32 ^ self.sign_mask as i32).wrapping_sub(self.sign_mask as i32)
+    pub fn round(&self) -> i64 {
+        ((self.value >> FRAC_BITS) as i64 ^ self.sign_mask as i64).wrapping_sub(self.sign_mask as i64)
     }
 
     #[inline]
-    pub const fn abs_round(&self) -> u32 {
-        (self.value >> 32) as u32
+    pub const fn abs_round(&self) -> u64 {
+        self.value >> FRAC_BITS
     }
 
     #[inline]
     pub fn sqr(&self) -> Self {
         Self {
             sign_mask: 0,
-            value: ((self.value as u128).pow(2) >> 32).saturating_into(),
+            value: ((self.value as u128).pow(2) >> FRAC_BITS).saturating_into(),
+            _marker: self._marker
         }
     }
 
@@ -77,92 +84,95 @@
 
         Self {
             sign_mask: POSITIVE_MASK,
-            value: integral_sqrt(self.value) << 16,
+            value: integral_sqrt(self.value) << (FRAC_BITS / 2),
+            _marker: self._marker
         }
     }
 
     #[inline]
-    pub fn with_sign(&self, is_negative: bool) -> FPNum {
-        FPNum {
+    pub fn with_sign(&self, is_negative: bool) -> Self {
+        Self {
             sign_mask: bool_mask(is_negative),
             ..*self
         }
     }
 
     #[inline]
-    pub const fn with_sign_as(self, other: FPNum) -> FPNum {
-        FPNum {
+    pub const fn with_sign_as(self, other: Self) -> Self {
+        Self {
             sign_mask: other.sign_mask,
             ..self
         }
     }
-
+/*
     #[inline]
     pub const fn point(self) -> FPPoint {
         FPPoint::new(self, self)
     }
-
+*/
     #[inline]
     const fn temp_i128(self) -> i128 {
         ((self.value ^ self.sign_mask) as i128).wrapping_sub(self.sign_mask as i128)
     }
 }
 
-impl From<i32> for FPNum {
+impl<const FRAC_BITS: u8> From<i32> for FixedPoint<FRAC_BITS> {
     #[inline]
     fn from(n: i32) -> Self {
-        FPNum {
+        Self {
             sign_mask: bool_mask(n < 0),
-            value: (n.abs() as u64) << 32,
+            value: (n.abs() as u64) << FRAC_BITS,
+            _marker: PhantomData,
         }
     }
 }
 
-impl From<u32> for FPNum {
+impl<const FRAC_BITS: u8> From<u32> for FixedPoint<FRAC_BITS> {
     #[inline]
     fn from(n: u32) -> Self {
         Self {
             sign_mask: POSITIVE_MASK,
-            value: (n as u64) << 32,
+            value: (n as u64) << FRAC_BITS,
+            _marker: PhantomData,
         }
     }
 }
 
-impl From<FPNum> for f64 {
+impl<const FRAC_BITS: u8> From<FixedPoint<FRAC_BITS>> for f64 {
     #[inline]
-    fn from(n: FPNum) -> Self {
+    fn from(n: FixedPoint<FRAC_BITS>) -> Self {
         if n.is_negative() {
-            n.value as f64 / -0x1_0000_0000i64 as f64
+            n.value as f64 / -(1i64 << FRAC_BITS) as f64
         } else {
-            n.value as f64 / 0x1_0000_0000i64 as f64
+            n.value as f64 / (1i64 << FRAC_BITS) as f64
         }
     }
 }
 
-impl PartialEq for FPNum {
+impl<const FRAC_BITS: u8> PartialEq for FixedPoint<FRAC_BITS> {
     #[inline]
     fn eq(&self, other: &Self) -> bool {
         self.value == other.value && (self.sign_mask == other.sign_mask || self.value == 0)
     }
 }
 
-impl Eq for FPNum {}
+impl<const FRAC_BITS: u8> Eq for FixedPoint<FRAC_BITS> {}
 
-impl PartialOrd for FPNum {
+impl<const FRAC_BITS: u8> PartialOrd for FixedPoint<FRAC_BITS> {
     #[inline]
     fn partial_cmp(&self, rhs: &Self) -> Option<cmp::Ordering> {
         Some(self.cmp(rhs))
     }
 }
 
-impl Ord for FPNum {
+impl<const FRAC_BITS: u8> Ord for FixedPoint<FRAC_BITS> {
     #[inline]
     fn cmp(&self, rhs: &Self) -> cmp::Ordering {
         self.temp_i128().cmp(&(rhs.temp_i128()))
     }
 }
 
-impl ops::Add for FPNum {
+impl<const FRAC_BITS: u8> ops::Add for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
@@ -172,11 +182,12 @@
         Self {
             sign_mask: mask,
             value: ((tmp as u64) ^ mask).wrapping_sub(mask),
+            _marker: PhantomData,
         }
     }
 }
 
-impl ops::Sub for FPNum {
+impl<const FRAC_BITS: u8> ops::Sub for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
@@ -186,7 +197,7 @@
     }
 }
 
-impl ops::Neg for FPNum {
+impl<const FRAC_BITS: u8> ops::Neg for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
@@ -194,23 +205,25 @@
         Self {
             sign_mask: !self.sign_mask,
             value: self.value,
+            _marker: PhantomData,
         }
     }
 }
 
-impl ops::Mul for FPNum {
+impl<const FRAC_BITS: u8> ops::Mul for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
     fn mul(self, rhs: Self) -> Self {
         Self {
             sign_mask: self.sign_mask ^ rhs.sign_mask,
-            value: ((self.value as u128 * rhs.value as u128) >> 32).saturating_into(),
+            value: ((self.value as u128 * rhs.value as u128) >> FRAC_BITS).saturating_into(),
+            _marker: PhantomData,
         }
     }
 }
 
-impl ops::Mul<i32> for FPNum {
+impl<const FRAC_BITS: u8> ops::Mul<i32> for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
@@ -218,23 +231,25 @@
         Self {
             sign_mask: self.sign_mask ^ bool_mask(rhs < 0),
             value: (self.value as u128 * rhs.abs() as u128).saturating_into(),
+            _marker: PhantomData,
         }
     }
 }
 
-impl ops::Div for FPNum {
+impl<const FRAC_BITS: u8> ops::Div for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
     fn div(self, rhs: Self) -> Self {
         Self {
             sign_mask: self.sign_mask ^ rhs.sign_mask,
-            value: (((self.value as u128) << 32) / rhs.value as u128).saturating_into(),
+            value: (((self.value as u128) << FRAC_BITS) / rhs.value as u128).saturating_into(),
+            _marker: PhantomData,
         }
     }
 }
 
-impl ops::Div<i32> for FPNum {
+impl<const FRAC_BITS: u8> ops::Div<i32> for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
@@ -242,11 +257,12 @@
         Self {
             sign_mask: self.sign_mask ^ bool_mask(rhs < 0),
             value: self.value / rhs.abs() as u64,
+            _marker: PhantomData,
         }
     }
 }
 
-impl ops::Div<u32> for FPNum {
+impl<const FRAC_BITS: u8> ops::Div<u32> for FixedPoint<FRAC_BITS> {
     type Output = Self;
 
     #[inline]
@@ -254,6 +270,7 @@
         Self {
             sign_mask: self.sign_mask,
             value: self.value / rhs as u64,
+            _marker: PhantomData,
         }
     }
 }
@@ -309,6 +326,7 @@
         FPNum {
             sign_mask: self.x_sign_mask as i32 as u64,
             value: self.x_value,
+            _marker: PhantomData,
         }
     }
 
@@ -317,6 +335,7 @@
         FPNum {
             sign_mask: self.y_sign_mask as i32 as u64,
             value: self.y_value,
+            _marker: PhantomData,
         }
     }
 
@@ -346,6 +365,7 @@
             FPNum {
                 sign_mask: POSITIVE_MASK,
                 value: integral_sqrt_ext(sqr),
+                _marker: PhantomData,
             }
         }
     }
@@ -504,16 +524,17 @@
 }
 
 #[inline]
-pub fn distance<T>(x: T, y: T) -> FPNum
+pub fn distance<T, const FRAC_BITS: u8>(x: T, y: T) -> FixedPoint<FRAC_BITS>
 where
     T: Into<i128> + std::fmt::Debug,
 {
-    let [x_squared, y_squared] = [x, y].map(|i| (i.into().pow(2) as u128).saturating_mul(2^64));
+    let [x_squared, y_squared] = [x, y].map(|i| (i.into().pow(2) as u128).saturating_mul(1 << FRAC_BITS << FRAC_BITS));
     let sqr: u128 = x_squared.saturating_add(y_squared);
 
-    FPNum {
+    FixedPoint {
         sign_mask: POSITIVE_MASK,
         value: integral_sqrt_ext(sqr),
+        _marker: PhantomData,
     }
 }