I'm trying to compare elements of structs using the `assert_approx_eq!` macro. The structs I'm working on are:

```rust
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Matrix3D {
    n: [[f64; 3]; 3],
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Vector3D {
    pub x: f64,
    pub y: f64,
    pub z: f64,
}
```
I've implemented the `IntoIterator` trait for both types, each with the help of a separate iterator struct:
```rust
impl IntoIterator for Matrix3D {
    type Item = f64;
    type IntoIter = Matrix3DIterator;

    fn into_iter(self) -> Self::IntoIter {
        Matrix3DIterator {
            n: self.n,
            index: 0,
        }
    }
}

pub struct Matrix3DIterator {
    n: [[f64; 3]; 3],
    index: usize,
}

impl Iterator for Matrix3DIterator {
    type Item = f64;

    // Yields all nine entries; the first array index varies fastest
    // (n[0][0], n[1][0], n[2][0], n[0][1], ...).
    fn next(&mut self) -> Option<Self::Item> {
        let i = self.index / 3;
        let j = self.index % 3;
        if i < 3 && j < 3 {
            self.index += 1;
            Some(self.n[j][i])
        } else {
            None
        }
    }
}

impl IntoIterator for Vector3D {
    type Item = f64;
    type IntoIter = Vector3DIterator;

    fn into_iter(self) -> Self::IntoIter {
        Vector3DIterator { x: self.x, y: self.y, z: self.z, index: 0 }
    }
}

pub struct Vector3DIterator {
    x: f64,
    y: f64,
    z: f64,
    index: usize,
}

impl Iterator for Vector3DIterator {
    type Item = f64;

    // Yields x, y, z in that order, then None.
    fn next(&mut self) -> Option<Self::Item> {
        let result = match self.index {
            0 => Some(self.x),
            1 => Some(self.y),
            2 => Some(self.z),
            _ => None,
        };
        self.index += 1;
        result
    }
}
```
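As an aside, the hand-written iterator structs can probably be avoided by delegating to the standard library's by-value array iterator (available since Rust 1.53). This is a sketch of an alternative, not the code from the question. Note that the flattened matrix yields entries row by row (`n[0][0]`, `n[0][1]`, ...), which differs from the order of `Matrix3DIterator`; for comparing two matrices elementwise only consistency between both sides should matter.

```rust
// Sketch of an alternative: reuse std's array iterators instead of
// writing custom iterator structs.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Matrix3D {
    n: [[f64; 3]; 3],
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Vector3D {
    pub x: f64,
    pub y: f64,
    pub z: f64,
}

impl IntoIterator for Vector3D {
    type Item = f64;
    type IntoIter = std::array::IntoIter<f64, 3>;

    fn into_iter(self) -> Self::IntoIter {
        // Calling IntoIterator::into_iter explicitly gives the by-value
        // array iterator in every edition.
        IntoIterator::into_iter([self.x, self.y, self.z])
    }
}

impl IntoIterator for Matrix3D {
    type Item = f64;
    // Iterate over the rows by value, then flatten each row into its entries.
    type IntoIter = std::iter::Flatten<std::array::IntoIter<[f64; 3], 3>>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIterator::into_iter(self.n).flatten()
    }
}
```

Because the associated `IntoIter` types are still concrete, either version should work with a generic `I: IntoIterator` helper.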
When writing tests for these structs, I made a helper function that iterates over them and compares each pair of elements. I'm struggling to express this in a generic way, so that I can use the same function to compare two matrices or two vectors with each other. Specifically, the item type of the collection must support subtraction and provide an `abs` function for the `assert_approx_eq!` macro. What I've got so far is:
```rust
use num_traits::Float;

fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I) -> ()
where
    I::Item: Float,
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e);
    }
}
```
which gives the following error:

```text
error[E0308]: mismatched types
   --> src/matrix.rs:411:13
    |
411 |             assert_approx_eq!(r, e);
    |             ^^^^^^^^^^^^^^^^^^^^^^^ expected associated type, found floating-point number
    |
    = note: expected associated type `<I as IntoIterator>::Item`
                          found type `{float}`
    = note: this error originates in the macro `assert_approx_eq` (in Nightly builds, run with -Z macro-backtrace for more info)
help: consider constraining the associated type `<I as IntoIterator>::Item` to `{float}`
    |
407 | fn elementwise_approx_comparison<I: IntoIterator<Item = {float}>>(result: I, expected: I) -> ()
    |
```
What does it mean that the compiler expected an associated type but found a floating-point number instead? How do I express that the function `elementwise_approx_comparison` should accept a collection `I` that can be turned into an iterator and that iterates over (in this case) floating-point numbers?
Answer:
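The problem is that `assert_approx_eq!` compares the absolute difference between the two values with a floating-point literal, and Rust doesn't know how to unify the type of that literal with the generic type `I::Item`: the `Float` bound only guarantees that an `I::Item` can be compared with another `I::Item`, not with an arbitrary float literal. That is what "expected associated type, found floating-point number" means. However, `assert_approx_eq!` lets you pass the comparison threshold as a third argument, and `Float` defines an `epsilon()` method that returns a "small positive value" suitable for this kind of use. Therefore you can do: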
```rust
use num_traits::Float;

fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I)
where
    I::Item: Float,
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e, Float::epsilon());
    }
}
```
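or, if the machine epsilon is too small for you (for `f64`, `Float::epsilon()` is `f64::EPSILON`, about 2.2e-16), you can convert a threshold of your choice with `NumCast::from`. `NumCast` is a supertrait of `Float`, and its `from` returns an `Option`, so the result has to be unwrapped:

```rust
use num_traits::{Float, NumCast};

fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I)
where
    I::Item: Float,
    I::Item: std::fmt::Debug,
{
    // Convert the threshold to the item type once, before the loop.
    let eps = <I::Item as NumCast>::from(1e-6_f64).unwrap();
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e, eps);
    }
}
```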
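To see the type mismatch without the macro, here is a rough hand-written stand-in. It assumes the macro's check boils down to `(a - b).abs() < eps`, and it replaces `num_traits::Float` with a small hypothetical trait, `ApproxFloat`, so that the sketch needs only `std`. The point it illustrates is that the threshold must have the same type `F` as the values being compared.

```rust
use std::fmt::Debug;
use std::ops::Sub;

// Hypothetical minimal stand-in for num_traits::Float: just enough to
// compute |a - b| and provide a default tolerance.
trait ApproxFloat: Copy + Debug + PartialOrd + Sub<Output = Self> {
    fn abs(self) -> Self;
    fn epsilon() -> Self;
}

impl ApproxFloat for f32 {
    fn abs(self) -> Self {
        f32::abs(self) // the inherent method, not a recursive call
    }
    fn epsilon() -> Self {
        f32::EPSILON
    }
}

impl ApproxFloat for f64 {
    fn abs(self) -> Self {
        f64::abs(self)
    }
    fn epsilon() -> Self {
        f64::EPSILON
    }
}

fn approx_eq<F: ApproxFloat>(a: F, b: F, eps: F) -> bool {
    // Writing `(a - b).abs() < 1.0e-6` here would not compile: the literal
    // is a `{float}`, but `<` on `F` is only known to accept another `F`.
    (a - b).abs() < eps
}
```

Passing the tolerance as an `F` (whether from `epsilon()` or from the caller) is what lets the same function work for both `f32` and `f64`.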
Answer:
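This happens because the two-argument form of the macro compares the difference against a floating-point literal defined inside the macro. You can additionally require the item type to be comparable with `f64`. However, this effectively restricts the type to `f64` (`f32`, for example, does not implement `PartialOrd<f64>`), which is not ideal.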
```rust
fn elementwise_approx_comparison<I: IntoIterator>(result: I, expected: I)
where
    I::Item: Float + PartialOrd<f64>,
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e);
    }
}
```
Another solution is to require the caller to supply the tolerance:
```rust
fn elementwise_approx_comparison<I: IntoIterator<Item = F>, F: Float>(
    result: I,
    expected: I,
    eps: F,
)
where
    I::Item: std::fmt::Debug,
{
    for (r, e) in std::iter::zip(result, expected) {
        assert_approx_eq!(r, e, eps);
    }
}
```
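Whichever bound you choose, note that `std::iter::zip` stops at the shorter of the two iterators, so a length mismatch between, say, two `Vec`s would pass silently (for `Matrix3D` and `Vector3D` the lengths are fixed, so this matters less). The following is a minimal std-only sketch of the caller-supplied-tolerance approach that also checks lengths; the body, the panic messages, and the `abs`-free comparison are illustrative choices, not part of either answer, and it avoids the `num_traits` and `assert_approx_eq` dependencies.

```rust
use std::fmt::Debug;
use std::ops::Sub;

// Illustrative variant of elementwise_approx_comparison that needs only
// Sub and PartialOrd on the item type (no abs, no num_traits).
fn elementwise_approx_comparison<I, F>(result: I, expected: I, eps: F)
where
    I: IntoIterator<Item = F>,
    F: Copy + Debug + PartialOrd + Sub<Output = F>,
{
    // Collect both sides so that their lengths can be compared.
    let result: Vec<F> = result.into_iter().collect();
    let expected: Vec<F> = expected.into_iter().collect();
    assert_eq!(result.len(), expected.len(), "length mismatch");

    for (i, (&r, &e)) in result.iter().zip(expected.iter()).enumerate() {
        // |r - e| < eps, written as two one-sided checks.
        assert!(
            r - e < eps && e - r < eps,
            "element {}: {:?} vs {:?} (eps = {:?})",
            i,
            r,
            e,
            eps
        );
    }
}
```

Calls such as `elementwise_approx_comparison(result, expected, 1e-6)` should then accept `Matrix3D`, `Vector3D`, arrays or `Vec`s, as long as both arguments have the same type.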