Recursive optimization, partial specialization on functions with implicit type deduction

Time:09-28

It is entirely possible to write a recursive function template that takes 'int N' as a template parameter and to specialize it so that the recursion ends when N = 0. But how do I do this for functions whose other template parameters are implicitly deduced?

#include <iostream>
#include <Eigen/Dense>
using Mat = Eigen::MatrixXd;
using Vec = Eigen::VectorXd;
using Eigen::MatrixBase;

template <int N, typename D0, typename D1, class Enable = void>
inline Vec recursive(const MatrixBase<D0> &A,
                     const MatrixBase<D1> &b)
{
    return (N % 2) ? recursive<N-1>(A, A * b) : recursive<N-1>(A, A.ldlt().solve(b));
}

template <int N, typename D0, typename D1>
inline Vec recursive <N, D0, D1, typename std::enable_if<N == 0>>(const MatrixBase<D0> &A,
                     const MatrixBase<D1> &b)
{
    return b;
}

int main()
{
    Mat A(2, 2);
    A(0, 0) = 10.0;
    A(1, 1) = 10.0;

    Vec b(2);
    b(0) = 1.0;
    b(1) = 5.0;

    Vec res(2);
    res = recursive<10>(A, b);

    std::cout << res << std::endl;
}

In the example above, D0 and D1 are deduced from the arguments passed in main. Functions will be instantiated for Mat A and Vec b, but also for the expression types produced inside the recursion (the types of A*b and A.ldlt().solve(b)).

The goal is for the compiler to eventually turn this recursion into a loop, so alternative implementations are welcome.

Compiler error output (-std=c++17):

EigenRecursive.cpp:16:45: error: non-class, non-variable partial specialization ‘recursive<N, D0, D1, std::enable_if<(N == 0), void> >’ is not allowed
   16 |                      const MatrixBase<D1> &b)

Edit: Implementing yosef's suggestion solved the partial specialization problem:

#include <iostream>
#include <Eigen/Dense>
using Mat = Eigen::MatrixXd;
using Vec = Eigen::VectorXd;
using Eigen::MatrixBase;

template <int N>
struct Recursive
{
    template <typename D0, typename D1>
    static inline Vec recursive(const MatrixBase<D0> &A,
                                const MatrixBase<D1> &b)
    {
        return Recursive<N-1>::recursive(A, A * b);// : Recursive<N-1>::recursive(A, A.ldlt().solve(b));
    }
};

template<>
struct Recursive<0>
{
    template <typename D0, typename D1>
    static inline Vec recursive(const MatrixBase<D0> &A,
                                const MatrixBase<D1> &b)
    {
        return b;
    }
};

int main()
{
    Mat A(2, 2);
    A(0, 0) = 10.0;
    A(1, 1) = 10.0;

    Vec b(2);
    b(0) = 1.0;
    b(1) = 5.0;

    Vec res(2);
    res = Recursive<10>::recursive(A, b);

    std::cout << res << std::endl;
}

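With the conditional operator (commented out above), compile time grows exponentially with N: both of its branches are instantiated, and each passes a different expression type as D1, so every level doubles the number of instantiations. That makes the current solution infeasible.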
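Counting instantiations makes the growth concrete. Let I(N) be the number of distinct recursive functions instantiated by a call to recursive<N> (notation introduced here for illustration). With the conditional operator every call instantiates two different recursive<N-1> specializations, because A*b and A.ldlt().solve(b) have different types; with a single branch it instantiates one:

```latex
\begin{aligned}
I_{\mathrm{cond}}(0) &= 1, & I_{\mathrm{cond}}(N) &= 1 + 2\,I_{\mathrm{cond}}(N-1) = 2^{N+1} - 1,\\
I_{\mathrm{single}}(0) &= 1, & I_{\mathrm{single}}(N) &= 1 + I_{\mathrm{single}}(N-1) = N + 1.
\end{aligned}
```

For N = 10 that is 2047 instantiations versus 11.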

Answer:

You can often avoid:

error: non-class, non-variable partial specialization

by nesting the function you want to specialize as a static member function of a dummy class template. Class templates can be partially specialized; function templates cannot.

For example, say I want to implement foo() with the following partial specialization:

template<int First, int Second>
void foo(){}

template<int First>
void foo<First, 0>(){}

int main(){
    foo<4,20>();
}

The simple example above will result in the 'non-class, non-variable partial specialization' error.

But our code can be replaced with the equivalent:

template <int First, int Second>
struct Foo{
    static void foo(){}
};

template <int First>
struct Foo<First, 0>{   
    static void foo(){}
};

int main()
{
    Foo<4,20>::foo();
}

And the second example compiles and runs with g++ 7.4.0.

Answer:

Yousef's partial specialization via structs solves the partial specialization problem, but compile time still grew exponentially because of the conditional operator. Replacing the conditional operator with template-parameter logic using std::enable_if resolved this issue:

#include <iostream>
#include <type_traits>
#include <Eigen/Dense>
using Mat = Eigen::MatrixXd;
using Vec = Eigen::VectorXd;
using Eigen::MatrixBase;

// Primary template: even N > 0 performs the LDLT solve step.
template <int N, class Enable = void>
struct Recursive
{
    template <typename D0, typename D1>
    static Vec recursive(const MatrixBase<D0> &A,
                         const MatrixBase<D1> &b)
    {
        return Recursive<N-1>::recursive(A, A.ldlt().solve(b));
    }
};

// Odd N: enable_if<...>::type is void, so this matches the default Enable = void.
// (Writing "class std::enable_if<...>" without ::type never matches void.)
template <int N>
struct Recursive<N, typename std::enable_if<(N % 2 == 1)>::type>
{
    template <typename D0, typename D1>
    static Vec recursive(const MatrixBase<D0> &A,
                         const MatrixBase<D1> &b)
    {
        return Recursive<N-1>::recursive(A, A * b);
    }
};

// Base case: N == 0 ends the recursion.
template <class Enable>
struct Recursive<0, Enable>
{
    template <typename D0, typename D1>
    static Vec recursive(const MatrixBase<D0> &A,
                         const MatrixBase<D1> &b)
    {
        return b;
    }
};

int main()
{
    Mat A(2, 2);
    A(0, 0) = 10.0;
    A(1, 1) = 10.0;

    Vec b(2);
    b(0) = 1.0;
    b(1) = 5.0;

    Vec res(2);
    res = Recursive<20>::recursive(A, b);

    std::cout << res << std::endl;
}

Answer:

Since C++17 there is if constexpr, which in my opinion makes recursion in templates much more readable. The template code would then look something like this (not tested with Eigen, since I don't have it installed, so MatrixBase below is a stand-in):

// Minimal stand-in for Eigen::MatrixBase so this compiles without Eigen.
template<typename T>
struct MatrixBase
{
};

template<unsigned int N, typename D0, typename D1>
inline auto recursive(const MatrixBase<D0>& a, const MatrixBase<D1>& b)
{
    if constexpr (N == 0)
    {
        return b;
    }
    else
    {
        if constexpr (N % 2 == 1)
        {
            // intended (untested with Eigen): recursive<N-1>(a, a * b)
            return recursive<N-1>(b, a);
        }
        else
        {
            // intended (untested with Eigen): recursive<N-1>(a, a.ldlt().solve(b))
            return recursive<N-1>(a, b); // modified my code for testing here!
        }
    }
}


int main()
{
    MatrixBase<int> a;
    MatrixBase<int> b;
    recursive<8>(a, b);
}