I am trying to use the currently unstable feature `generic_const_exprs` to let the users of my library know the resulting dimensions of the types they generate.

My use case is much more complex, but I've created a minimal example with a reproducible error. The main idea is that, given a `Tensor<N>` as input, I want to output a `Tensor<M>`, where `M` is `{N + 1}`. `Tensor<N>` is a trait, and it is implemented both for `Constant<N>` and for `Variable<N>`. This is the code:
```rust
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
struct Variable<const N: usize>;
struct Constant<const N: usize>;

trait Tensor<const N: usize> {
    fn get_dim(&self) -> usize {
        N
    }
}

trait ConvertTo<Y> {
    fn convert(&self) -> Y;
}

impl<const N: usize> Tensor<N> for Variable<N> {}
impl<const N: usize> Tensor<N> for Constant<N> {}

impl<const N: usize, const M: usize> ConvertTo<Constant<M>> for Variable<N> {
    fn convert(&self) -> Constant<M> {
        Constant::<M>
    }
}

impl<const N: usize, const M: usize> ConvertTo<Variable<M>> for Constant<N> {
    fn convert(&self) -> Variable<M> {
        Variable::<M>
    }
}

fn convert_plus_one<const N: usize, X, Y>(x: X) -> Y
where
    X: Tensor<N> + ConvertTo<Y>,
    Y: Tensor<{ N + 1 }>,
{
    x.convert()
}

fn main() {
    let x = Constant::<3>;
    let y = convert_plus_one(x);
    // At this point the compiler should know that y is a Variable<N> with N = 4
    // and it implements Tensor<4>, because Tensor<N> is implemented for Variable<N>
    assert_eq!(y.get_dim(), 4);
}
```
And this is the compiler error:
```
   Compiling playground v0.0.1 (/playground)
error[E0277]: the trait bound `Variable<{_: usize}>: Tensor<{ N + 1 }>` is not satisfied
  --> src/main.rs:41:13
   |
41 |     let y = convert_plus_one(x);
   |             ^^^^^^^^^^^^^^^^ the trait `Tensor<{ N + 1 }>` is not implemented for `Variable<{_: usize}>`
   |
   = help: the trait `Tensor<N>` is implemented for `Variable<N>`
note: required by a bound in `convert_plus_one`
  --> src/main.rs:34:8
   |
31 | fn convert_plus_one<const N: usize, X, Y>(x: X) -> Y
   |    ---------------- required by a bound in this
...
34 |     Y: Tensor<{ N + 1 }>,
   |        ^^^^^^^^^^^^^^^^^ required by this bound in `convert_plus_one`

For more information about this error, try `rustc --explain E0277`.
error: could not compile `playground` due to previous error
```
I am running out of ideas on how to fix this. Am I missing something, or is this just impossible to do in the current state of `generic_const_exprs`?
Answer:
I managed to make it work by using associated items on traits. The trick was to express my bounds in a single expression.

From this:

```rust
where
    X: Tensor<N> + ConvertTo<Y>,
    Y: Tensor<{ N + 1 }>,
```

To this:

```rust
where
    X: Tensor + ConvertTo<{ <X as Tensor>::N + 1 }>,
```

The original example didn't work because Rust evaluates each trait bound independently. On one hand it tries to prove that `Constant<3>: ConvertTo<?>`, and on the other that `?: Tensor<4>`, which only makes sense if both are considered at the same time.

Associated consts and associated types on traits provide the syntax needed to put all the bounds in a single expression. Here is the final result, which compiles perfectly:
```rust
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
#![feature(associated_type_bounds)]

struct Variable<const N: usize>;
struct Constant<const N: usize>;

// The dimension is now an associated const instead of a trait parameter.
trait Tensor {
    const N: usize;
    fn get_dim(&self) -> usize {
        Self::N
    }
}

// The target dimension is the trait parameter; the output type is an
// associated type, so it is fully determined by `Self` and `N`.
trait ConvertTo<const N: usize> {
    type To;
    fn convert(&self) -> Self::To;
}

impl<const N: usize> Tensor for Variable<N> {
    const N: usize = N;
}

impl<const N: usize> Tensor for Constant<N> {
    const N: usize = N;
}

impl<const N: usize, const M: usize> ConvertTo<M> for Variable<N> {
    type To = Constant<M>;
    fn convert(&self) -> Self::To {
        Constant::<M>
    }
}

impl<const N: usize, const M: usize> ConvertTo<M> for Constant<N> {
    type To = Variable<M>;
    fn convert(&self) -> Self::To {
        Variable::<M>
    }
}

// Every bound is written in terms of `X` alone, so no const parameter is left
// for the caller to infer.
fn convert_plus_one<X>(x: X) -> <X as ConvertTo<{ <X as Tensor>::N + 1 }>>::To
where
    X: Tensor + ConvertTo<{ <X as Tensor>::N + 1 }>,
{
    x.convert()
}

fn main() {
    let x = Constant::<3>;
    let y = convert_plus_one(x);
    assert_eq!(y.get_dim(), 4);
}
```
And now I can rest.
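If you need something similar on stable Rust, where `generic_const_exprs` is unavailable, the same associated-type idea can be approximated roughly as follows. This is a sketch under stated assumptions, not the code above: the `ConvertPlusOne` trait and the `impl_plus_one!` macro are hypothetical names, and the `N + 1` step is written out by hand for a fixed list of dimensions (0 through 4 here), so any other size simply has no impl.

```rust
// Stable-Rust sketch: no `generic_const_exprs`, so `{ N + 1 }` is spelled out
// per dimension by a macro instead of being computed generically.
struct Variable<const N: usize>;
struct Constant<const N: usize>;

trait Tensor {
    const N: usize;
    fn get_dim(&self) -> usize {
        Self::N
    }
}

impl<const D: usize> Tensor for Variable<D> {
    const N: usize = D;
}

impl<const D: usize> Tensor for Constant<D> {
    const N: usize = D;
}

// Hypothetical trait: `Out` is the converted tensor with one more dimension.
// Because `Out` is an associated type, it is determined by `Self` alone.
trait ConvertPlusOne {
    type Out: Tensor;
    fn convert(&self) -> Self::Out;
}

// Generates the impls for each listed `n => n + 1` pair (only these sizes work).
macro_rules! impl_plus_one {
    ($($n:literal => $m:literal),* $(,)?) => {
        $(
            impl ConvertPlusOne for Constant<{ $n }> {
                type Out = Variable<{ $m }>;
                fn convert(&self) -> Self::Out {
                    Variable
                }
            }

            impl ConvertPlusOne for Variable<{ $n }> {
                type Out = Constant<{ $m }>;
                fn convert(&self) -> Self::Out {
                    Constant
                }
            }
        )*
    };
}

impl_plus_one!(0 => 1, 1 => 2, 2 => 3, 3 => 4, 4 => 5);

fn convert_plus_one<X: ConvertPlusOne>(x: X) -> X::Out {
    x.convert()
}

fn main() {
    let x = Constant::<3>;
    let y = convert_plus_one(x); // `y` is inferred as `Variable<4>`
    assert_eq!(y.get_dim(), 4);
}
```

The design choice mirrors the answer: because the output is an associated type of the input, the compiler never has to solve for a free const parameter. The cost is that the supported dimensions must be enumerated up front.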
Answer:
Rust is weak at inferring const expressions, at least currently. The compiler can infer that, since there is only one `ConvertTo` impl for `Constant`, `Y` must be `Variable<_>`, but it is unable to infer the `N` for this `Variable`. If you specify it explicitly, it works:

```rust
let y = convert_plus_one::<3, _, Variable<4>>(x);
```
(I specified the `3` only because leaving it as `_` requires the `generic_arg_infer` feature; with that feature enabled, inferring it works fine.)
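The same inference gap can be reproduced on stable Rust by dropping the `{ N + 1 }` bound, which needs `generic_const_exprs`. In this sketch the hypothetical `convert_to` function takes the output dimension `M` as its own parameter instead of `{ N + 1 }`. Because `M` and `Y` appear only in the return type and the bounds, the call site has to name them, just as with `convert_plus_one` above.

```rust
struct Variable<const N: usize>;
struct Constant<const N: usize>;

trait Tensor<const N: usize> {
    fn get_dim(&self) -> usize {
        N
    }
}
impl<const N: usize> Tensor<N> for Variable<N> {}
impl<const N: usize> Tensor<N> for Constant<N> {}

trait ConvertTo<Y> {
    fn convert(&self) -> Y;
}

impl<const N: usize, const M: usize> ConvertTo<Variable<M>> for Constant<N> {
    fn convert(&self) -> Variable<M> {
        Variable::<M>
    }
}

// Hypothetical variant of `convert_plus_one`: the output dimension `M` is a
// separate parameter rather than `{ N + 1 }`, so this compiles on stable.
fn convert_to<const M: usize, X, Y>(x: X) -> Y
where
    X: ConvertTo<Y>,
    Y: Tensor<M>,
{
    x.convert()
}

fn main() {
    let x = Constant::<3>;
    // Writing just `convert_to(x)` would leave `M` unconstrained: the single
    // impl tells the compiler `Y = Variable<_>`, but nothing picks the `_`.
    let y = convert_to::<4, _, Variable<4>>(x);
    assert_eq!(y.get_dim(), 4);
}
```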
I believe you could report this as a bug, though I haven't checked for duplicates.