Cast to inferred function type in Rust

Time:10-01

Let's say we have a type FARPROC that we receive as the result of certain Win32 operations, such as GetProcAddress. Its definition is as follows:

pub type FARPROC = unsafe extern "system" fn() -> isize;

This type then needs to be converted to a concrete function type so we can call it. As developers, we know the function signature (from documentation), but we need to pass that knowledge on to the compiler. A straightforward approach is to use transmute with an explicit function type, as follows:

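use std::mem::transmute;

let address: FARPROC = unsafe { GetProcAddress(module, proc).unwrap() };
// Transmute the pointer value itself; `transmute(&address)` would reinterpret the
// address of the local variable `address`, not the function it points to.
let function: extern "system" fn(i32) = unsafe { transmute(address) };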
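For reference, the explicit-type approach can be exercised without Windows. In this sketch, `add_one`, `fake_get_proc_address` and `call_through_farproc` are hypothetical names: a local `extern "system"` function stands in for the export GetProcAddress would return, and FARPROC is the alias given above. It illustrates the transmute under those assumptions; it is not a Win32 wrapper.

```rust
use std::mem;

// The question's alias for an untyped export address.
pub type FARPROC = unsafe extern "system" fn() -> isize;

// Hypothetical stand-in for a DLL export. On non-Windows targets "system"
// is the same as "C", so this compiles and runs anywhere.
extern "system" fn add_one(x: i32) -> i32 {
    x + 1
}

// Hypothetical stand-in for GetProcAddress: erases the real signature.
pub fn fake_get_proc_address() -> FARPROC {
    // SAFETY: all fn pointers have the same size; nothing is called here.
    unsafe { mem::transmute(add_one as extern "system" fn(i32) -> i32) }
}

pub fn call_through_farproc(arg: i32) -> i32 {
    let address: FARPROC = fake_get_proc_address();
    // Transmute the pointer value, with the signature written out by hand.
    // SAFETY: `address` really points to a function with exactly this signature and ABI.
    let function: extern "system" fn(i32) -> i32 = unsafe { mem::transmute(address) };
    function(arg)
}
```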

But what if we wanted to infer the function type from one of the existing functions defined in our Rust code? Let's say we have a function with a lengthy signature:

use std::ffi::{c_char, c_void};

pub fn foo(arg1: i32, arg2: *mut c_void, arg3: *const c_char) {
   // snip
}

In order to keep our code DRY, is it possible to use that function type with transmute?

My attempted solution looks like the following snippet:


use std::mem::transmute_copy;

pub fn cast_to_function<F>(address: FARPROC, _fn: &F) -> F {
    unsafe { transmute_copy(&address) }
}

/// Usage

let address: FARPROC = unsafe { GetProcAddress(module, proc).unwrap() };
let function = cast_to_function(address, &foo);
function(1, ...);

This attempted solution is based on similar C++ code that works:

template<typename T>
T FnCast(void* fnToCast, T pFnCastTo) {
    return (T)fnToCast;
}

/// Usage

void bar(int arg1, void* arg2, const char* arg3){
   // snip
}

auto address = (void*) GetProcAddress(...); 
auto function = FnCast(address, &bar);
function(1, address, "bar");

The problem with my attempted solution in Rust is that cast_to_function always returns a function whose address points to the referenced function (foo), rather than to the address we provided. Given all of the above, is it possible to infer the function type from an existing function's definition and cast an arbitrary function pointer to it?

Answer:

To understand why your attempted solution doesn't work, you need to know about the different kinds of function types in Rust (function items vs. function pointers); see this answer. When you make this call:

let function = cast_to_function(address, &foo);

the argument &foo is a reference to the specific function item foo; it is not a function pointer like address. Because of cast_to_function()'s signature, the return type F is inferred as that same function item type. Function items are zero-sized types, so the transmute_copy() inside cast_to_function() copies zero bytes: address is never read, and calling the result always calls foo.
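A small, platform-independent sketch of this behavior, assuming a locally defined function (hypothetically named `loaded`, erased by the hypothetical helper `farproc_for_loaded`) plays the role of the export behind the FARPROC. It shows that foo's function item is zero-sized and that `cast_to_function(address, &foo)` still calls foo.

```rust
use std::mem::{self, transmute_copy};

pub type FARPROC = unsafe extern "system" fn() -> isize;

// Hypothetical function that the FARPROC actually points to.
pub fn loaded(x: i32) -> i32 {
    x + 100
}

// Template whose signature we want to reuse (same signature as `loaded`).
pub fn foo(x: i32) -> i32 {
    x * 10
}

// Hypothetical helper: erase `loaded` into a FARPROC, as GetProcAddress would.
pub fn farproc_for_loaded() -> FARPROC {
    // SAFETY: fn pointers share one size; the FARPROC is never called as-is.
    unsafe { mem::transmute(loaded as fn(i32) -> i32) }
}

// The question's attempt. Passing `&foo` makes F foo's function item type,
// which is zero-sized, so transmute_copy copies zero bytes of `address`.
pub fn cast_to_function<F>(address: FARPROC, _fn: &F) -> F {
    unsafe { transmute_copy(&address) }
}
```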

The simplest solution is to cast foo to a function pointer before calling cast_to_function():

let function = cast_to_function(address, foo as fn(_, _, _));

(you'll also need to change cast_to_function()'s signature to take _fn by value: _fn: F).

This still involves a bit of boilerplate for the manual cast, though as demonstrated you can use placeholders (_) instead of writing out each argument type. If you get the number of arguments wrong, you get a compile error, so you can't accidentally cast foo to a wrong signature.
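A runnable sketch of this simpler fix, under the same kind of assumptions: the hypothetical `subtract` is the function behind the FARPROC, and the hypothetical `load_subtract` stands in for GetProcAddress. The `F: Copy` bound and the size assertion are additions of this sketch, not part of the answer; the assertion catches a caller who forgets the `as fn(..)` cast. Because this template returns a value, the cast also needs a return placeholder: `foo as fn(_, _) -> _`.

```rust
use std::mem::{self, transmute_copy};

pub type FARPROC = unsafe extern "system" fn() -> isize;

// Hypothetical function that the FARPROC actually points to.
pub fn subtract(a: i32, b: i32) -> i32 {
    a - b
}

// Template: only its signature matters (its behavior deliberately differs).
pub fn foo(a: i32, b: i32) -> i32 {
    a + b
}

// Hypothetical helper standing in for GetProcAddress.
pub fn load_subtract() -> FARPROC {
    // SAFETY: fn pointers share one size; the FARPROC is never called as-is.
    unsafe { mem::transmute(subtract as fn(i32, i32) -> i32) }
}

// `_fn` is taken by value, so F is whatever type the caller passes:
// here a function pointer rather than a zero-sized function item.
pub fn cast_to_function<F: Copy>(address: FARPROC, _fn: F) -> F {
    // Fails if F is still a zero-sized function item.
    assert_eq!(mem::size_of::<F>(), mem::size_of::<FARPROC>());
    unsafe { transmute_copy(&address) }
}
```

With these definitions, `cast_to_function(load_subtract(), foo as fn(_, _) -> _)` yields a `fn(i32, i32) -> i32` that calls `subtract`, not `foo`.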


Getting rid of the required casting in the caller code is more complicated. I came up with this solution, which uses a trait to do the casting, and uses a macro to write the impls for functions of varying arity:

use std::mem;

/// Casts a function pointer to a new function pointer, with the new type derived from a
/// template function.
///
/// SAFETY:
/// - `fn_ptr` must be a function pointer (e.g. `fn(i32) -> i32`)
/// - `_template_fn`'s signature must be the real signature of the function `fn_ptr` points to
unsafe fn fn_ptr_cast<T, U, V>(fn_ptr: T, _template_fn: U) -> V
where
    T: Copy + 'static,
    U: FnPtrCast<V>,
{
    debug_assert_eq!(mem::size_of::<T>(), mem::size_of::<usize>());

    U::fn_ptr_cast(fn_ptr)
}

unsafe trait FnPtrCast<U> {
    unsafe fn fn_ptr_cast<T>(fn_ptr: T) -> U;
}

macro_rules! impl_fn_cast {
    ( $($arg:ident),* ) => {
        unsafe impl<Fun, Out, $($arg),*> FnPtrCast<fn($($arg),*) -> Out> for Fun
        where
            Fun: Fn($($arg),*) -> Out,
        {
            unsafe fn fn_ptr_cast<T>(fn_ptr: T) -> fn($($arg),*) -> Out {
                ::std::mem::transmute(::std::ptr::read(&fn_ptr as *const T as *const *const T))
            }
        }
    }
}

impl_fn_cast!();
impl_fn_cast!(A);
impl_fn_cast!(A, B);
impl_fn_cast!(A, B, C);
impl_fn_cast!(A, B, C, D);
impl_fn_cast!(A, B, C, D, E);
impl_fn_cast!(A, B, C, D, E, F);
impl_fn_cast!(A, B, C, D, E, F, G);
impl_fn_cast!(A, B, C, D, E, F, G, H);

This lets you do

let function = unsafe { fn_ptr_cast(address, foo) };

without any manual casting.

Note: This last solution assumes function pointers and data pointers are the same size, which is not guaranteed on every platform.
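Building on that note, here is a sketch of a variant of the trait-based approach that checks the source against the size of the target function pointer type itself and copies the bits with `std::mem::transmute_copy`, instead of reading them through a `*const T` data pointer. It is trimmed to three arguments, and `double`, `template` and `erased_double` are hypothetical names for illustration. One caveat applies to the answer's version as well: the generated type is a Rust-ABI `fn(..)`, so calling it is only well-defined when the loaded function really uses that ABI. A Win32 export is `extern "system"`, so real code would need the target type to carry that ABI.

```rust
use std::mem;

pub type FARPROC = unsafe extern "system" fn() -> isize;

/// Maps a template callable's signature to the matching Rust-ABI fn pointer type `P`.
///
/// SAFETY: `fn_ptr` must be a function pointer whose real signature and ABI are exactly `P`.
pub unsafe trait FnPtrCast<P> {
    unsafe fn fn_ptr_cast<T: Copy>(fn_ptr: T) -> P;
}

macro_rules! impl_fn_cast {
    ( $($arg:ident),* ) => {
        unsafe impl<Fun, Out, $($arg),*> FnPtrCast<fn($($arg),*) -> Out> for Fun
        where
            Fun: Fn($($arg),*) -> Out,
        {
            unsafe fn fn_ptr_cast<T: Copy>(fn_ptr: T) -> fn($($arg),*) -> Out {
                // Compare with the target fn pointer type, not with a data pointer.
                assert_eq!(
                    ::std::mem::size_of::<T>(),
                    ::std::mem::size_of::<fn($($arg),*) -> Out>()
                );
                // Copy the pointer's bits into the new fn pointer type.
                unsafe { ::std::mem::transmute_copy(&fn_ptr) }
            }
        }
    };
}

impl_fn_cast!();
impl_fn_cast!(A);
impl_fn_cast!(A, B);
impl_fn_cast!(A, B, C);

/// The target type `V` is inferred from `_template_fn`, as in the answer.
pub unsafe fn fn_ptr_cast<T: Copy, U: FnPtrCast<V>, V>(fn_ptr: T, _template_fn: U) -> V {
    unsafe { U::fn_ptr_cast(fn_ptr) }
}

// Hypothetical stand-ins: `double` is the "loaded" function,
// `template` only supplies the signature.
pub fn double(x: i64) -> i64 {
    x * 2
}

pub fn template(_x: i64) -> i64 {
    0
}

// Hypothetical helper that erases `double` into a FARPROC, as GetProcAddress would.
pub fn erased_double() -> FARPROC {
    // SAFETY: fn pointers share one size; the FARPROC is never called as-is.
    unsafe { mem::transmute(double as fn(i64) -> i64) }
}
```

Usage mirrors the answer: `let function = unsafe { fn_ptr_cast(erased_double(), template) };` gives a `fn(i64) -> i64` that calls `double`.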
