Generic function over collections of numbers

Time:01-21

I'm trying to compare elements of structs using the assert_approx_eq! macro. The structs I'm working on are:

#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Matrix3D {
    n: [[f64; 3]; 3],
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Vector3D {
    pub x: f64,
    pub y: f64,
    pub z: f64,
}

I've implemented the IntoIterator trait for both types, each with the help of its own iterator struct:

impl IntoIterator for Matrix3D {
    type Item = f64;
    type IntoIter = Matrix3DIterator;

    fn into_iter(self) -> Self::IntoIter {
        Matrix3DIterator {
            n: self.n,
            index: 0,
        }
    }
}

pub struct Matrix3DIterator {
    n: [[f64; 3]; 3],
    index: usize,
}

impl Iterator for Matrix3DIterator {
    type Item = f64;

    fn next(&mut self) -> Option<Self::Item> {
        let i = self.index / 3;
        let j = self.index % 3;
        if i < 3 && j < 3 {
            self.index += 1;
            Some(self.n[j][i])
        } else {
            None
        }
    }
}
impl IntoIterator for Vector3D {
    type Item = f64;
    type IntoIter = Vector3DIterator;

    fn into_iter(self) -> Self::IntoIter {
        Vector3DIterator { x: self.x, y: self.y, z: self.z, index: 0}
    }
}

pub struct Vector3DIterator {
    x: f64,
    y: f64,
    z: f64,
    index: usize,
}

impl Iterator for Vector3DIterator {
    type Item = f64;

    fn next(&mut self) -> Option<Self::Item> {
        let result = match self.index {
            0 => Some(self.x),
            1 => Some(self.y),
            2 => Some(self.z),
            _ => None
        };
        self.index += 1;
        result
    }
}
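As a side note, the hand-written iterator structs are not strictly required. A more compact sketch, assuming Rust 1.53 or later (where arrays implement IntoIterator by value), delegates to the standard library's std::array::IntoIter. It keeps the element order of Matrix3DIterator above, which yields n[j][i] and therefore walks the matrix column by column.

```rust
// Same structs as in the question; only the iterator plumbing differs.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Matrix3D {
    n: [[f64; 3]; 3],
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Vector3D {
    pub x: f64,
    pub y: f64,
    pub z: f64,
}

impl IntoIterator for Vector3D {
    type Item = f64;
    // The standard by-value array iterator replaces Vector3DIterator.
    type IntoIter = std::array::IntoIter<f64, 3>;

    fn into_iter(self) -> Self::IntoIter {
        // Called through the trait so it yields values on every edition.
        IntoIterator::into_iter([self.x, self.y, self.z])
    }
}

impl IntoIterator for Matrix3D {
    type Item = f64;
    type IntoIter = std::array::IntoIter<f64, 9>;

    fn into_iter(self) -> Self::IntoIter {
        let n = self.n;
        // Column-major order, i.e. the same n[j][i] order as Matrix3DIterator.
        IntoIterator::into_iter([
            n[0][0], n[1][0], n[2][0],
            n[0][1], n[1][1], n[2][1],
            n[0][2], n[1][2], n[2][2],
        ])
    }
}
```

Because both structs are Copy, calling into_iter() on them leaves the caller's original values usable.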

When writing tests for these structs I made a helper function that iterates over them and compares each element. I'm struggling to express this in a generic way, such that I can use the same function to compare two matrices or two vectors with each other. Specifically, the element type of the collection must support subtraction and provide an abs function, as the assert_approx_eq! macro requires. What I've got so far is:

use num_traits::Float;

fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I) -> ()
where
    I::Item: Float,
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e);
    }
}

which gives the following error

error[E0308]: mismatched types
   --> src/matrix.rs:411:13
    |
411 |             assert_approx_eq!(r, e);
    |             ^^^^^^^^^^^^^^^^^^^^^^^ expected associated type, found floating-point number
    |
    = note: expected associated type `<I as IntoIterator>::Item`
                          found type `{float}`
    = note: this error originates in the macro `assert_approx_eq` (in Nightly builds, run with -Z macro-backtrace for more info)
help: consider constraining the associated type `<I as IntoIterator>::Item` to `{float}`
    |
407 |     fn elementwise_approx_comparison<I: IntoIterator<Item = {float}>>(result: I, expected: I) -> () 
    |                                                                     

What does it mean that the compiler expected an associated type but found a floating-point number instead? How do I express that the function elementwise_approx_comparison should accept a collection I that can be made into an iterator and that iterates over (in this case) floating-point numbers?

Answer:

The problem is that, when called with two arguments, assert_approx_eq! compares the absolute difference of the two values against a default threshold written as a float literal. An unsuffixed float literal can only become a concrete float type such as f32 or f64, so Rust cannot unify it with the generic I::Item; that is what "expected associated type, found floating-point number" means. However, assert_approx_eq! lets you pass the comparison threshold as a third argument, and Float defines an epsilon() method that gives you a "small positive value" suitable for this kind of use. Therefore you can do:

use assert_approx_eq::assert_approx_eq;
use num_traits::Float;

fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I)
where
    I::Item: Float,
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        // epsilon() returns a value of type I::Item, so the types line up.
        assert_approx_eq!(r, e, Float::epsilon());
    }
}

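Or, if the machine epsilon is too small for you (f64::EPSILON is about 2.2e-16), you can convert a threshold of your own choosing into the element type with NumCast::from. Float has NumCast as a supertrait, and NumCast::from returns an Option, hence the unwrap():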

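use assert_approx_eq::assert_approx_eq;
use num_traits::{Float, NumCast};

fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I)
where
    I::Item: Float,
    I::Item: std::fmt::Debug,
{
    // NumCast::from returns None if the value cannot be represented in I::Item.
    let eps = <I::Item as NumCast>::from(1e-6).unwrap();
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e, eps);
    }
}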
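The underlying idea is that the threshold reaches the comparison as a value of the element type rather than as an untyped literal. The same idea can be sketched without num_traits or the macro: below, ApproxFloat, magnitude and tolerance are hypothetical names standing in for Float, abs and epsilon, and the loop performs by hand roughly the check that the three-argument assert_approx_eq! performs.

```rust
use std::fmt::Debug;
use std::ops::Sub;

// Hypothetical stand-in for num_traits::Float, reduced to what the check needs.
pub trait ApproxFloat: Copy + PartialOrd + Debug + Sub<Output = Self> {
    // Absolute value, playing the role of Float::abs.
    fn magnitude(self) -> Self;
    // Tolerance of type Self, playing the role of Float::epsilon.
    fn tolerance() -> Self;
}

impl ApproxFloat for f32 {
    fn magnitude(self) -> Self {
        self.abs()
    }
    fn tolerance() -> Self {
        f32::EPSILON
    }
}

impl ApproxFloat for f64 {
    fn magnitude(self) -> Self {
        self.abs()
    }
    fn tolerance() -> Self {
        f64::EPSILON
    }
}

pub fn elementwise_approx_comparison<I>(result: I, expected: I)
where
    I: IntoIterator,
    I::Item: ApproxFloat,
{
    // The threshold is a value of type I::Item, not an untyped float literal,
    // so the comparison below type-checks for every implementing type.
    let eps = <I::Item as ApproxFloat>::tolerance();
    for (r, e) in result.into_iter().zip(expected) {
        let diff = (r - e).magnitude();
        assert!(diff < eps, "{:?} != {:?} (diff {:?})", r, e, diff);
    }
}
```

Because f64::EPSILON is about 2.2e-16, this check is very strict: two f64 values that differ by 1e-9 already fail it, which is why a caller-chosen threshold is often more practical.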

Answer:

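This is because the macro's two-argument form uses a floating-point literal as its default threshold. You can additionally require the element type to be comparable to f64 (PartialOrd<f64>), which lets that literal be inferred as f64. However, this effectively restricts the element type to f64, which is not ideal.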

use assert_approx_eq::assert_approx_eq;
use num_traits::Float;

fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I)
where
    <I as IntoIterator>::Item: Float + PartialOrd<f64>,
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e);
    }
}

Another solution is to require the caller to supply the threshold:

use assert_approx_eq::assert_approx_eq;
use num_traits::Float;

fn elementwise_approx_comparison<I: IntoIterator<Item = F>, F: Float>(
    result: I,
    expected: I,
    eps: F,
) where
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e, eps);
    }
}
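Note that std::iter::zip stops at the end of the shorter iterator, so comparing two collections of different lengths passes silently. The sketch below takes the threshold from the caller, as in the last snippet, and also fails on a length mismatch. To stay std-only it swaps num_traits::Float for a different technique: every element is widened to f64 through an Into<f64> bound, which covers f32 and f64 elements. That substitution is an assumption of this sketch, not part of either answer.

```rust
use std::fmt::Debug;

// Compares two collections element by element. Panics if any pair differs by
// eps or more, or if the collections have different lengths.
// Elements are widened to f64 (assumption: Into<f64> instead of num_traits::Float).
pub fn elementwise_approx_comparison<I>(result: I, expected: I, eps: f64)
where
    I: IntoIterator,
    I::Item: Into<f64> + Copy + Debug,
{
    let mut result = result.into_iter();
    let mut expected = expected.into_iter();
    let mut index = 0;
    loop {
        match (result.next(), expected.next()) {
            (Some(r), Some(e)) => {
                let (r64, e64): (f64, f64) = (r.into(), e.into());
                let diff = (r64 - e64).abs();
                assert!(diff < eps, "element {}: {:?} vs {:?} (diff {})", index, r, e, diff);
            }
            // Both iterators are exhausted at the same time: lengths match.
            (None, None) => break,
            // zip() would silently stop here; treat a length mismatch as a failure.
            _ => panic!("collections differ in length (mismatch at element {})", index),
        }
        index += 1;
    }
}
```

With the question's types, a call such as elementwise_approx_comparison(m1, m2, 1e-9) for two Matrix3D values m1 and m2 should type-check, since both structs iterate over f64.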