How to specialize a template function for different data types in which the procedures are similar?


For example, I want to implement a matrix multiplication template function using AVX2. (Suppose "Matrix" is a well implemented template class)

template <typename T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if (typeid(T) == typeid(float)) {
        // use __m256 to store floats
        // use _mm256_load_ps, _mm256_mul_ps, _mm256_add_ps
    } else if (typeid(T) == typeid(double)) {
        // use __m256d to store doubles
        // use _mm256_load_pd, _mm256_mul_pd, _mm256_add_pd
    } else {
        //...
    }
}

Since a type is not a runtime value, the program cannot branch on T to decide whether it should use __m256, __m256d, or something else, which makes the code long and awkward. Is there another way to avoid this?
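For reference, here is a minimal runnable sketch of the runtime-typeid approach (the `precision_tag` function is hypothetical, standing in for the branches of `matmul`). Note that with a plain `if`, every branch must still compile for every T, which is exactly why this breaks down once real intrinsics appear in the branches:

```cpp
#include <string>
#include <typeinfo>

// Hypothetical illustration of the runtime-typeid approach: this compiles
// only because every branch is valid for every T. With real intrinsics,
// the wrong-type branches would fail to compile.
template <typename T>
std::string precision_tag() {
    if (typeid(T) == typeid(float)) {
        return "float path (__m256)";
    } else if (typeid(T) == typeid(double)) {
        return "double path (__m256d)";
    }
    return "scalar fallback";
}
```

Also note that `typeid` objects compare directly with `==`; there is no need to compare the implementation-defined `.name()` strings.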

CodePudding user response:

In C++17 and later, you can use if constexpr:

#include <type_traits>

template <typename T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if constexpr (std::is_same_v<T, float>) {
        // use __m256 to store floats
        // use _mm256_load_ps, _mm256_mul_ps, _mm256_add_ps
    } else if constexpr (std::is_same_v<T, double>) {
        // use __m256d to store doubles
        // use _mm256_load_pd, _mm256_mul_pd, _mm256_add_pd
    } else {
        //...
    }
}
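The key difference from a plain `if` is that the discarded branch of `if constexpr` is never instantiated, so each instantiation only has to compile its own path. A minimal runnable sketch of the dispatch (scalar stand-in, not real intrinsics):

```cpp
#include <string>
#include <type_traits>

// Scalar stand-in for the AVX2 dispatch above: the untaken if constexpr
// branch is discarded at compile time, so code for the other type never
// needs to compile in this instantiation.
template <typename T>
std::string simd_register_for() {
    if constexpr (std::is_same_v<T, float>) {
        return "__m256";   // 8 floats per 256-bit register
    } else if constexpr (std::is_same_v<T, double>) {
        return "__m256d";  // 4 doubles per 256-bit register
    } else {
        return "no AVX2 path";
    }
}
```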

Otherwise, just use overloads:

Matrix<float> matmul(const Matrix<float>& mat1, const Matrix<float>& mat2) {
    // use __m256 to store floats
    // use _mm256_load_ps, _mm256_mul_ps, _mm256_add_ps
}

Matrix<double> matmul(const Matrix<double>& mat1, const Matrix<double>& mat2) {
    // use __m256d to store doubles
    // use _mm256_load_pd, _mm256_mul_pd, _mm256_add_pd
}

...
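With overloads, the compiler selects the right body from the argument type; there is no runtime check at all. A trivial runnable sketch (hypothetical `matmul_path` stand-ins for the two overloads above):

```cpp
#include <string>

// Hypothetical scalar stand-ins for the float/double matmul overloads:
// overload resolution picks the body from the argument type at compile time.
std::string matmul_path(float)  { return "_mm256_*_ps"; }
std::string matmul_path(double) { return "_mm256_*_pd"; }
```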

CodePudding user response:

First, you could create overloads for the functions _mm256_load_* and _mm256_mul_* etc.:

#include <immintrin.h>

namespace foo {
// Note: _mm256_load_ps/_mm256_load_pd require 32-byte-aligned data;
// use _mm256_loadu_ps/_mm256_loadu_pd for unaligned loads.
inline __m256 mm256_load(float const& a) {
    return _mm256_load_ps(&a);
}
inline __m256d mm256_load(double const& a) {
    return _mm256_load_pd(&a);
}

inline __m256 mm256_mul(__m256 m1, __m256 m2) {
    return _mm256_mul_ps(m1, m2);
}
inline __m256d mm256_mul(__m256d m1, __m256d m2) {
    return _mm256_mul_pd(m1, m2);
}
// add more functions here
} // namespace foo
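To see how generic code drives such an overload set, here is a runnable analogue that needs no AVX2 hardware. `VecF`/`VecD`, `vec_load`, `vec_mul`, and `first_lane_product` are all hypothetical stand-ins mirroring `__m256`/`__m256d` and the `foo::` overloads:

```cpp
#include <array>

// Hypothetical stand-ins for __m256/__m256d so the pattern runs anywhere:
// VecF holds 8 floats and VecD holds 4 doubles, mirroring the 256-bit widths.
struct VecF { std::array<float, 8> v; };
struct VecD { std::array<double, 4> v; };

inline VecF vec_load(const float* p)  { VecF r{}; for (int i = 0; i < 8; ++i) r.v[i] = p[i]; return r; }
inline VecD vec_load(const double* p) { VecD r{}; for (int i = 0; i < 4; ++i) r.v[i] = p[i]; return r; }
inline VecF vec_mul(VecF a, VecF b)   { for (int i = 0; i < 8; ++i) a.v[i] *= b.v[i]; return a; }
inline VecD vec_mul(VecD a, VecD b)   { for (int i = 0; i < 4; ++i) a.v[i] *= b.v[i]; return a; }

// Generic code never names the vector type: overload resolution plus auto
// pick the float or double implementation from T.
template <typename T>
T first_lane_product(const T* a, const T* b) {
    auto va = vec_load(a);   // VecF for float, VecD for double
    auto vb = vec_load(b);
    return vec_mul(va, vb).v[0];
}
```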

You could then create a type trait to give you the proper AVX2 type for float and double:

#include <type_traits>

namespace foo {
template<class T> struct floatstore;

template<> struct floatstore<float> {
    using value_type = __m256;
};
template<> struct floatstore<double> {
    using value_type = __m256d;
};

template<class T>
using floatstore_t = typename floatstore<T>::value_type;
} // namespace foo
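The same trait pattern can be verified without `<immintrin.h>` by mapping to ordinary types. Here is a hypothetical `accum` trait, analogous to `floatstore` but mapping each element type to a wider accumulator type instead of a vector register type:

```cpp
#include <type_traits>

// Hypothetical trait, same shape as floatstore: a primary template left
// undefined, explicit specializations per supported type, plus an alias.
template <class T> struct accum;

template <> struct accum<float>  { using type = double;      };
template <> struct accum<double> { using type = long double; };

template <class T>
using accum_t = typename accum<T>::type;

// The mapping is resolved entirely at compile time:
static_assert(std::is_same_v<accum_t<float>, double>);
static_assert(std::is_same_v<accum_t<double>, long double>);
```

Using an unsupported type (say, `accum_t<int>`) fails to compile, because only the specializations define `type`; the same holds for `floatstore` above.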

With that, your final function could use the above like this:

template<class T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    T some_float = ...;              // float or double
    foo::floatstore_t<T> a_variable; // __m256 or __m256d

    a_variable = foo::mm256_load(some_float);
}