Correctly make blanket implementation of argmax trait for iterators

Time:09-22

I decided to try writing a trait in Rust with a blanket implementation. As a test case, I chose a trait that returns the argmax over an iterator together with the maximal element. Right now the implementation is

use num::Bounded;

trait Argmax<T> {
    fn argmax(self) -> (usize, T);
}

impl<I, T> Argmax<T> for I
where
    I: Iterator<Item = T>,
    T: std::cmp::PartialOrd + Bounded,
{
    fn argmax(self) -> (usize, T) {
        self.enumerate()
            .fold((0, T::min_value()), |(i_max, val_max), (i, val)| {
                if val >= val_max {
                    (i, val)
                } else {
                    (i_max, val_max)
                }
            })
    }
}

Testing it with this code

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    println!("v: {:?}", v);
    let (i_max, v_max) = v.iter().copied().argmax();
    println!("i_max: {}\nv_max: {}", i_max, v_max);
}

works, while

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    println!("v: {:?}", v);
    let (i_max, v_max) = v.iter().argmax();
    println!("i_max: {}\nv_max: {}", i_max, v_max);
}

doesn't compile, and gives this error:

  --> src/main.rs:27:35
   |
27 |     let (i_max, v_max) = v.iter().argmax();
   |                                   ^^^^^^ method cannot be called on `std::slice::Iter<'_, {float}>` due to unsatisfied trait bounds
   |
   = note: the following trait bounds were not satisfied:
           `<&std::slice::Iter<'_, {float}> as Iterator>::Item = _`
           which is required by `&std::slice::Iter<'_, {float}>: Argmax<_>`
           `&std::slice::Iter<'_, {float}>: Iterator`
           which is required by `&std::slice::Iter<'_, {float}>: Argmax<_>`

error: aborting due to previous error

For more information about this error, try `rustc --explain E0599`.

I figure the problem comes from the fact that .iter() iterates over references, while .iter().copied() iterates over actual values, but I still can't wrap my head around the error message, or how to make the trait generic so that it also works when iterating over references.

Answer:

> I still can't wrap my head around the error message

Unfortunately, the error message is cryptic: it doesn't tell you what <&std::slice::Iter<'_, {float}> as Iterator>::Item actually is, which is the key fact, only what it isn't. (It possibly doesn't help that {float}, a not-yet-chosen numeric type, is involved. I'm also not sure what the & is doing there, since no reference to an iterator is involved.)

However, if you look up the documentation for std::slice::Iter<'a, T> you will find that its item type is &'a T, so in this case, &'a {float}.
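A quick way to confirm these item types is to let the compiler check them. The sketch below uses two hypothetical helper functions (yields_refs and yields_values are illustrative only, not part of std or num) whose bounds only accept iterators with the stated item type:

```rust
// Hypothetical helper: only accepts iterators whose items are `&'a f64`.
fn yields_refs<'a, I: Iterator<Item = &'a f64>>(it: I) -> usize {
    it.count()
}

// Hypothetical helper: only accepts iterators whose items are owned `f64`.
fn yields_values<I: Iterator<Item = f64>>(it: I) -> usize {
    it.count()
}

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    // `std::slice::Iter<'_, f64>` has `Item = &f64`, so `.iter()` fits `yields_refs`.
    println!("{}", yields_refs(v.iter()));
    // `.copied()` turns each `&f64` into an `f64`, so this fits `yields_values`.
    println!("{}", yields_values(v.iter().copied()));
    // Swapping the two calls would fail to compile with a type mismatch on `Item`.
}
```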

This tells you what you already know: the iterator is over references. Unfortunately the error message doesn't tell you much about the remainder of the problem. But if I check out the docs for num::Bounded I find, unsurprisingly, that Bounded is not implemented for references to numbers. This is unsurprising because references must be to values which exist in memory, and so it can be tricky or impossible to construct references which aren't borrowing some existing data structure. (I think it might be possible in this case, but num hasn't implemented that.)

and how to make it generic and working with looping over references.

It's not possible as long as you choose to use the Bounded trait, because Bounded is not implemented for references to primitive numbers, and it's not possible to provide two different blanket implementations for &T and T.

(You could implement Bounded for a type of your own, MyWrapper<f32>, and references to it, but then users have to deal with that wrapper.)

Here are some options:

  1. Keep the code you currently have, and live with the need to write .copied(). This situation is not at all uncommon with other iterators; don't make code hairier just to avoid one extra function call.

  2. Write a version of argmax() with return type Option<(usize, T)> that produces None when the iterator is empty. Then there is no need for Bounded, and the code works with only the PartialOrd constraint. It also avoids returning a meaningless index and value for an empty iterator, which is generally considered a virtue in Rust code. The caller can always use .unwrap_or_else() if (0, T::min_value()) is an appropriate answer for their application.

  3. Write a version of argmax() which takes a separate initial value, rather than using T::min_value().
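Option 3 could look roughly like the sketch below. The trait name ArgmaxFrom and the method argmax_from are hypothetical names chosen for illustration; the fold is the same as in the question, but the starting value comes from the caller, so no Bounded (and no num dependency) is needed, only PartialOrd:

```rust
// Hypothetical variant of the question's trait: the caller supplies the
// starting value instead of relying on `Bounded::min_value()`.
trait ArgmaxFrom<T> {
    fn argmax_from(self, init: T) -> (usize, T);
}

impl<I, T> ArgmaxFrom<T> for I
where
    I: Iterator<Item = T>,
    T: PartialOrd,
{
    fn argmax_from(self, init: T) -> (usize, T) {
        // Same fold as the question; index 0 is reported if nothing beats `init`.
        self.enumerate()
            .fold((0, init), |(i_max, val_max), (i, val)| {
                if val >= val_max {
                    (i, val)
                } else {
                    (i_max, val_max)
                }
            })
    }
}

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    // Over references: the initial value is a `&f64` as well.
    let (i_max, v_max) = v.iter().argmax_from(&f64::MIN);
    println!("i_max: {}\nv_max: {}", i_max, v_max);
    // Over owned values, as with `.copied()` in the question.
    let (i, x) = v.iter().copied().argmax_from(f64::MIN);
    println!("i: {}\nx: {}", i, x);
}
```

Note that, like the original, this still returns (0, init) for an empty iterator, so it only sidesteps Bounded; it does not address the empty-input issue that option 2 handles.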

Answer:

PartialOrd is implemented for references; the issue here is that num::Bounded is not. If you make the argmax function follow the convention of returning None when the iterator is empty (as Iterator::max does), you can get rid of the dependency on the num::Bounded trait entirely:

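fn argmax(self) -> Option<(usize, T)> {
    let mut enumerated = self.enumerate();
    // The first element (if any) is the initial maximum; an empty iterator gives None.
    enumerated.next().map(move |first| {
        // Fold over the remaining elements; `>=` keeps the last of several equal maxima.
        enumerated.fold(first, |(i_max, val_max), (i, val)| {
            if val >= val_max {
                (i, val)
            } else {
                (i_max, val_max)
            }
        })
    })
}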
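Putting that method into the question's trait gives a complete program along these lines (a sketch: the trait and impl are the question's, with Bounded dropped and the return type changed to Option). It shows .iter() now compiling directly, because T becomes &f64 and references to PartialOrd types are themselves PartialOrd:

```rust
trait Argmax<T> {
    fn argmax(self) -> Option<(usize, T)>;
}

impl<I, T> Argmax<T> for I
where
    I: Iterator<Item = T>,
    T: PartialOrd, // no `Bounded` needed any more
{
    fn argmax(self) -> Option<(usize, T)> {
        let mut enumerated = self.enumerate();
        // The first element seeds the fold; an empty iterator yields `None`.
        enumerated.next().map(move |first| {
            enumerated.fold(first, |(i_max, val_max), (i, val)| {
                if val >= val_max { (i, val) } else { (i_max, val_max) }
            })
        })
    }
}

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    // Works directly on `.iter()`: `T` is `&f64`, and `&f64: PartialOrd`.
    if let Some((i_max, v_max)) = v.iter().argmax() {
        println!("i_max: {}\nv_max: {}", i_max, v_max);
    }
    // For an empty input the caller chooses a fallback explicitly.
    let empty: Vec<f64> = Vec::new();
    let (i, x) = empty.iter().copied().argmax().unwrap_or_else(|| (0, f64::MIN));
    println!("fallback: {} {}", i, x);
}
```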

This also allows Argmax to be used with iterators over custom, potentially non-numeric, comparable types.
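For example, with the Option-returning version, string slices and a user-defined type both work. The Version struct below is a hypothetical example type; deriving PartialOrd on it compares its fields in order (major, then minor):

```rust
trait Argmax<T> {
    fn argmax(self) -> Option<(usize, T)>;
}

impl<I, T> Argmax<T> for I
where
    I: Iterator<Item = T>,
    T: PartialOrd,
{
    fn argmax(self) -> Option<(usize, T)> {
        let mut enumerated = self.enumerate();
        enumerated.next().map(move |first| {
            enumerated.fold(first, |(i_max, val_max), (i, val)| {
                if val >= val_max { (i, val) } else { (i_max, val_max) }
            })
        })
    }
}

// Hypothetical non-numeric type; the derived ordering is lexicographic over fields.
#[derive(Debug, PartialEq, PartialOrd)]
struct Version(u32, u32);

fn main() {
    // Items are `&&str`; strings compare lexicographically.
    let words = ["pear", "apple", "zebra", "fig"];
    println!("{:?}", words.iter().argmax());

    // Items are `&Version`; no numeric bounds are involved anywhere.
    let versions = vec![Version(1, 2), Version(2, 0), Version(1, 9)];
    println!("{:?}", versions.iter().argmax());
}
```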

As an aside, you may want to convert the Argmax trait to use an associated type instead of a generic type, just like Iterator.
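One possible reading of that suggestion is sketched below: instead of a generic parameter T, the trait takes Iterator as a supertrait and reuses its associated Item type (the same pattern crates such as itertools use for extension traits). This is an assumption about what the answer meant, not the only way to introduce an associated type:

```rust
// No generic parameter: the element type is `Self::Item` from `Iterator`.
trait Argmax: Iterator + Sized {
    fn argmax(self) -> Option<(usize, Self::Item)>;
}

impl<I> Argmax for I
where
    I: Iterator,
    I::Item: PartialOrd,
{
    fn argmax(self) -> Option<(usize, Self::Item)> {
        let mut enumerated = self.enumerate();
        // Seed with the first element; `>=` keeps the last of equal maxima.
        enumerated.next().map(move |first| {
            enumerated.fold(first, |(i_max, val_max), (i, val)| {
                if val >= val_max { (i, val) } else { (i_max, val_max) }
            })
        })
    }
}

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    // The element type is determined by the iterator; there is no `T` to infer.
    println!("{:?}", v.iter().argmax());
    println!("{:?}", "bca".chars().argmax());
}
```

A practical benefit is that each iterator type has exactly one Argmax implementation, so type inference never has to pick a T, which matches how Iterator itself is designed.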

