Home > Mobile >  Making a function that can accept different types of templated class inputs
Making a function that can accept different types of templated class inputs

Time:09-25

I have an application where I need to reduce the memory usage of my matrix classes. Many of my matrices are symmetric, and so should only require $N(N+1)/2$ memory instead of $N^2$. I can't simply have a MatrixSym class that derives from my Matrix class, because a derived class always uses at least as much memory as its base class. I've defined a MatrixAbstract class from which both classes inherit; however, I am having trouble getting a matrix multiplication function to work on both of them.

My question

How can I define a C = multiply(A,B) function that accepts either a Matrix or MatrixSym object for any input/output? My minimal working example is not so minimal, because it contains four definitions of multiply that all use the same code and should be combined into one.

Requirements

  • Static allocation only
  • Must use std::array for matrix elements
  • Minimal repeated code: I have many matrix operations and the only thing different about MatrixSym is the way it stores and accesses elements from the underlying matVals array.
#include <array> 
#include <iostream> 

/// Common storage and element access shared by Matrix and MatrixSym.
///
/// Derived classes decide how a (row, col) pair maps onto the packed
/// `matVals` array by overriding index_from_rc(); get/set are written once
/// here.
///
/// @tparam ROWS  number of logical rows
/// @tparam COLS  number of logical columns
/// @tparam NUMEL number of doubles actually stored
template<unsigned int ROWS, unsigned int COLS, unsigned int NUMEL>
class MatrixAbstract
{
private:
    static constexpr unsigned int numel = NUMEL;
    // Value-initialised so a freshly constructed matrix is all zeros. Without
    // the `{}` the doubles are indeterminate, and reading them is undefined.
    std::array<double, NUMEL> matVals{};   // The array of matrix elements
public:
    // The constructor must carry the class's own name; it was previously
    // spelled MatrixSuperclass(), which does not compile.
    MatrixAbstract() = default;
    MatrixAbstract(const MatrixAbstract&) = default;
    MatrixAbstract& operator=(const MatrixAbstract&) = default;
    MatrixAbstract(MatrixAbstract&&) = default;
    MatrixAbstract& operator=(MatrixAbstract&&) = default;
    // Polymorphic base: allow deletion through a base pointer. This costs no
    // extra space because the class already has a vtable pointer.
    virtual ~MatrixAbstract() = default;

    /// Map (row, col) to an index into matVals; supplied by derived classes.
    virtual unsigned int index_from_rc(const unsigned int& row, const unsigned int& col) const = 0;

    /// Return the value stored for (row, col).
    double get_value(const int& row, const int& col) const
    {
        return this->matVals[this->index_from_rc(row, col)];
    }

    /// Store `value` for (row, col).
    void set_value(const int& row, const int& col, double value)
    {
        this->matVals[this->index_from_rc(row, col)] = value;
    }
};

// The derived matrix classes name their base `MatrixSuperclass`; keep that
// spelling valid as an alias of MatrixAbstract.
template<unsigned int ROWS, unsigned int COLS, unsigned int NUMEL>
using MatrixSuperclass = MatrixAbstract<ROWS, COLS, NUMEL>;

/// Dense ROWS x COLS matrix stored in row-major order.
template<unsigned int ROWS, unsigned int COLS, unsigned int NUMEL = ROWS*COLS>
class Matrix : public MatrixSuperclass<ROWS, COLS, NUMEL>
{
    // A dense matrix indexes up to ROWS*COLS-1; fewer elements would overflow.
    static_assert(NUMEL == ROWS * COLS, "Matrix storage must hold ROWS*COLS elements");
public:
    Matrix() = default;

    /// Row-major linear index in matVals for (row, col).
    unsigned int index_from_rc(const unsigned int& row, const unsigned int& col) const override {
        return row * COLS + col;
    }
};

template<unsigned int ROWS, unsigned int COLS = ROWS, unsigned int NUMEL = ROWS*(ROWS 1)/2> 
class MatrixSym : public MatrixSuperclass<ROWS, COLS, NUMEL>
{
public:
    MatrixSym(){}
    // get the linear index in matVals corresponding to a row,column input (Symmetric matrix)
    unsigned int index_from_rc(const unsigned int& row, const unsigned int& col) const {
        unsigned int z;
        return ( ( z = ( row < col ? col : row ) ) * ( z   1 ) >> 1 )   ( col < row ? col : row ) ;
    }
};

// The four multiply overloads below differ only in their argument/return
// types, so they all delegate to one shared loop (matrix_detail::multiply_into).
// Keeping four thin overloads preserves the original overload set exactly.

namespace matrix_detail {

/// Compute out = lhs * rhs for any matrix types exposing get_value/set_value.
///
/// @tparam ROWS  rows of lhs and of the result
/// @tparam COLS  columns of rhs and of the result
/// @tparam INNER columns of lhs == rows of rhs
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER,
         class Out, class Lhs, class Rhs>
void multiply_into(Out& out, const Lhs& lhs, const Rhs& rhs)
{
    for (unsigned int r = 0; r < ROWS; ++r) {
        for (unsigned int c = 0; c < COLS; ++c) {
            double val = 0.0;
            for (unsigned int k = 0; k < INNER; ++k) {
                val += lhs.get_value(r, k) * rhs.get_value(k, c);
            }
            out.set_value(r, c, val);
        }
    }
}

} // namespace matrix_detail

/// Multiply a Matrix and a Matrix and output a Matrix.
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
Matrix<ROWS,COLS> multiply (const Matrix<ROWS,INNER>& inMatrix1, const Matrix<INNER,COLS>& inMatrix2) {
    Matrix<ROWS,COLS> outMatrix;
    matrix_detail::multiply_into<ROWS, COLS, INNER>(outMatrix, inMatrix1, inMatrix2);
    return outMatrix;
}

/// Multiply a Matrix and a MatrixSym and output a Matrix.
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
Matrix<ROWS,COLS> multiply (const Matrix<ROWS,INNER>& inMatrix1, const MatrixSym<INNER,COLS>& inMatrix2) {
    Matrix<ROWS,COLS> outMatrix;
    matrix_detail::multiply_into<ROWS, COLS, INNER>(outMatrix, inMatrix1, inMatrix2);
    return outMatrix;
}

/// Multiply a MatrixSym and a Matrix and output a Matrix.
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
Matrix<ROWS,COLS> multiply (const MatrixSym<ROWS,INNER>& inMatrix1, const Matrix<INNER,COLS>& inMatrix2) {
    Matrix<ROWS,COLS> outMatrix;
    matrix_detail::multiply_into<ROWS, COLS, INNER>(outMatrix, inMatrix1, inMatrix2);
    return outMatrix;
}

/// Multiply a MatrixSym and a MatrixSym and output a MatrixSym.
///
/// Caveat: MatrixSym has a single storage slot per (r, c)/(c, r) pair. The
/// loop writes (c, r) with r < c after (r, c), so each slot ends up holding
/// the lower-triangle entry of the product. The result is only correct when
/// the product is itself symmetric (e.g. S*S). Use a Matrix-returning
/// overload for general products.
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
MatrixSym<ROWS,COLS> multiply (const MatrixSym<ROWS,INNER>& inMatrix1, const MatrixSym<INNER,COLS>& inMatrix2) {
    MatrixSym<ROWS,COLS> outMatrix;
    matrix_detail::multiply_into<ROWS, COLS, INNER>(outMatrix, inMatrix1, inMatrix2);
    return outMatrix;
}

int main()
{
    Matrix<3,3> A;
    MatrixSym<3> S;

    // Give both operands defined values before multiplying, so the products
    // are meaningful and no element is read before it has been written.
    // For S, setting the lower triangle (c <= r) covers every stored slot.
    for (unsigned int r = 0; r < 3; ++r) {
        for (unsigned int c = 0; c < 3; ++c) {
            A.set_value(r, c, r * 3.0 + c);
            if (c <= r) {
                S.set_value(r, c, r + c + 1.0);
            }
        }
    }

    Matrix<3,3> AtimesA = multiply(A,A);
    Matrix<3,3> AtimesS = multiply(A,S);
    Matrix<3,3> StimesA = multiply(S,A);
    MatrixSym<3> StimesS = multiply(S,S);   // S*S is symmetric, so MatrixSym can hold it

    // Make sure that symmetric matrix S is indeed smaller than A
    std::cout << "sizeof(A)/sizeof(double) = " << sizeof(A)/sizeof(double) << '\n';
    std::cout << "sizeof(S)/sizeof(double) = " << sizeof(S)/sizeof(double) << '\n';
    return 0;
}

Outputs:

sizeof(A)/sizeof(double) = 9
sizeof(S)/sizeof(double) = 15

What I've tried

  • If I try to make the function use MatrixAbstract as arguments I need to play around with the template parameters, since I can't have NUMEL as a template parameter. Further, I can't instantiate a MatrixAbstract for the return value.
  • I can't figure out how to template the function for either Matrix or MatrixSym. My hunch is that this is the way to solve it, but I don't understand how to template the function for a templated class input, but with still being able to use ROWS and COLS as template arguments.
// Does not compile. `class OutMatrixType` (and likewise InMatrix1Type,
// InMatrix2Type) declares a *type* parameter, so it cannot be followed by a
// template argument list such as `<ROWS,COLS>`. The compiler reports
// "'OutMatrixType' is not a template".
// Using the parameter as a template requires a template template parameter,
// e.g. `template<unsigned int, unsigned int, unsigned int> class OutMatrixType`.
// Even then, OutMatrixType appears only in the return type, so it cannot be
// deduced and would have to be given explicitly at the call site.
template<class InMatrix1Type, class InMatrix2Type, class OutMatrixType, unsigned int ROWS, unsigned int COLS, unsigned int INNER>
OutMatrixType<ROWS,COLS> multiply (InMatrix1Type<ROWS,INNER>& inMatrix1, InMatrix2Type<INNER,COLS>& inMatrix2)

gives me a compiler error starting with

error: ‘OutMatrixType’ is not a template
 OutMatrixType<ROWS,COLS> multiply (InMatrix1Type<ROWS,INNER>& inMatrix1, InMatrix2Type<INNER,COLS>& inMatrix2)
 ^~~~~~~~~~~~~

Answer:

So I removed a lot of noise in your original question by removing specific code relevant to matrices and tried to boil it down to just what was being asked.

What I did here was I stuck with the base class and instead allow the user to specify the return type they want for the multiply function.

I haven't heavily tested this but it appears to do what you're after.

To reduce the number of non-type template parameters on the multiply function, I included getters for the rows, cols, and numel. You can of course scrap those and go back to how you had it before, but those member functions let you assert that the passed-in parameters are correct.

All that being said, as @Pete Becker mentioned, you could also accomplish this without inheritance here. Read his comment for further information.

This isn't a complete example but may help you with the final solution.

/// Runtime-polymorphic interface shared by every matrix type in this sketch.
class MatrixBase {
public:
    virtual double get_value( int row, int column ) const = 0;
    // Not const: a setter has to be able to modify the object's storage.
    virtual void set_value( int row, int column, double value ) = 0;
    virtual std::uint32_t get_rows( ) const = 0;
    virtual std::uint32_t get_cols( ) const = 0;
    virtual std::uint32_t get_numel( ) const = 0;
    virtual ~MatrixBase( ) = default;
};

/// Dense matrix placeholder: reports its shape; element access is stubbed.
template<std::uint32_t Rows, std::uint32_t Cols, std::uint32_t Numel>
class Matrix : public MatrixBase {
public:
    // Stub: always returns 1.0 until storage is wired up to values_.
    double get_value( int /*row*/, int /*column*/ ) const override {
        return 1.0;
    }

    // Stub: currently a no-op.
    void set_value( int /*row*/, int /*column*/, double /*value*/ ) override { }

    std::uint32_t get_rows( ) const override {
        return Rows;
    }

    std::uint32_t get_cols( ) const override {
        return Cols;
    }

    std::uint32_t get_numel( ) const override {
        return Numel;
    }

private:
    std::array<double, Numel> values_{ };
};

/// Second matrix type with the same interface (element access is stubbed).
template<std::uint32_t Rows, std::uint32_t Cols, std::uint32_t Numel>
class MatrixSum : public MatrixBase {
public:
    // Stub: always returns 1.0 until storage is wired up to values_.
    double get_value( int /*row*/, int /*column*/ ) const override {
        return 1.0;
    }

    // Stub: currently a no-op.
    void set_value( int /*row*/, int /*column*/, double /*value*/ ) override { }

    std::uint32_t get_rows( ) const override {
        return Rows;
    }

    std::uint32_t get_cols( ) const override {
        return Cols;
    }

    std::uint32_t get_numel( ) const override {
        return Numel;
    }

private:
    std::array<double, Numel> values_{ };
};


/// Multiply two matrices through the MatrixBase interface, producing a T.
///
/// @tparam T     result type. It must derive from MatrixBase and be default
///               constructible; its get_rows()/get_cols() give the result shape.
/// @tparam Inner shared inner dimension. Precondition (not checked): m1 is
///               T-rows x Inner and m2 is Inner x T-cols.
/// @param m1 left operand
/// @param m2 right operand
/// @return out with out(r, c) = sum over k of m1(r, k) * m2(k, c)
template<typename T, std::uint32_t Inner>
static T multiply( const MatrixBase& m1, const MatrixBase& m2 ) {
    static_assert( std::is_base_of_v<MatrixBase, T>, 
        "Return type must derive from MatrixBase" );

    static_assert( std::is_default_constructible_v<T>, 
        "Type must be default constructable" );

    T out{ };

    const std::uint32_t rows{ out.get_rows( ) };
    const std::uint32_t cols{ out.get_cols( ) };
    for ( std::uint32_t r = 0; r < rows; ++r ) {
        for ( std::uint32_t c = 0; c < cols; ++c ) {
            double sum{ 0.0 };
            for ( std::uint32_t k = 0; k < Inner; ++k ) {
                sum += m1.get_value( static_cast<int>( r ), static_cast<int>( k ) ) *
                       m2.get_value( static_cast<int>( k ), static_cast<int>( c ) );
            }
            out.set_value( static_cast<int>( r ), static_cast<int>( c ), sum );
        }
    }

    return out;
}

int main( ) {
    // Two operands of different matrix types with matching 10x10 shapes.
    MatrixSum<10, 10, 10> matrix_sum{ };
    Matrix<10, 10, 10> matrix{ };

    // The result type is chosen explicitly by the caller in each case.
    auto product = multiply<Matrix<10, 10, 10>, 10>( matrix_sum, matrix );
    auto sum_product = multiply<MatrixSum<10, 10, 10>, 10>( matrix, matrix_sum );
}

Or, if you're using C++20, you can define a concept with a requires clause that specifies the interface a passed-in type must have. Using the definitions above, that might look like this:

// Mat: any default-constructible type offering the shape queries and the
// element accessors that multiply() relies on.
template<typename T>
concept Mat =
    std::is_default_constructible_v<T> &&
    requires( T mat, int r, int c, double v ) {
        { mat.get_rows( ) }       -> std::same_as<std::uint32_t>;
        { mat.get_cols( ) }       -> std::same_as<std::uint32_t>;
        { mat.get_numel( ) }      -> std::same_as<std::uint32_t>;
        { mat.get_value( r, c ) } -> std::same_as<double>;
        mat.set_value( r, c, v );
    };

/// Multiply two Mat-conforming matrices, producing a T1.
///
/// @tparam Inner shared inner dimension. Precondition (not checked): m1 is
///               T1-rows x Inner and m2 is Inner x T1-cols.
/// @tparam T1    result type; it appears only in the return type, so callers
///               must specify it explicitly.
/// @tparam T2    left operand type (deduced)
/// @tparam T3    right operand type (deduced)
/// @return out with out(r, c) = sum over k of m1(r, k) * m2(k, c)
template<std::uint32_t Inner, Mat T1, Mat T2, Mat T3>
static T1 multiply( const T2& m1, const T3& m2 ) {
    T1 out{ };

    const std::uint32_t rows{ out.get_rows( ) };
    const std::uint32_t cols{ out.get_cols( ) };
    for ( std::uint32_t r = 0; r < rows; ++r ) {
        for ( std::uint32_t c = 0; c < cols; ++c ) {
            double sum{ 0.0 };
            for ( std::uint32_t k = 0; k < Inner; ++k ) {
                sum += m1.get_value( static_cast<int>( r ), static_cast<int>( k ) ) *
                       m2.get_value( static_cast<int>( k ), static_cast<int>( c ) );
            }
            out.set_value( static_cast<int>( r ), static_cast<int>( c ), sum );
        }
    }

    return out;
}

With this approach you can completely remove the MatrixBase class if you wanted and just constrain the multiply function to types that expose the desired functionality.

Answer:

The function multiply can be a template that takes two (possibly different) argument types, provided they support the appropriate interface. Like this:

template <unsigned ROWS, unsigned COLS>
struct Matrix {
    double get_value(int row, int col) const;
    void set_value(int row, int col, double value);
};

template <unsigned ROWS, unsigned COLS>
struct Symmetric_Matrix {
    double get_value(int row, int col) const;
    void set_value(int row, int col, double value);
};

/// Multiply two matrices whose class templates may differ.
///
/// M1 and M2 are template template parameters. The compiler deduces them,
/// and the values of ROWS, INNER and COLS, from the arguments.
/// @return a dense Matrix with result(r, c) = sum over k of m1(r, k) * m2(k, c)
template <unsigned ROWS, unsigned COLS, unsigned INNER,
    template <unsigned, unsigned> class M1,
    template <unsigned, unsigned> class M2>
Matrix<ROWS, COLS> multiply(const M1<ROWS, INNER>& m1,
    const M2<INNER, COLS>& m2) {
    Matrix<ROWS, COLS> result;
    for (unsigned r = 0; r < ROWS; ++r) {
        for (unsigned c = 0; c < COLS; ++c) {
            double val = 0.0;
            for (unsigned rc = 0; rc < INNER; ++rc) {
                val += m1.get_value(r, rc) * m2.get_value(rc, c);
            }
            result.set_value(r, c, val);
        }
    }
    return result;
}

M1 and M2 are template template parameters; that is, they're templates that are used as template arguments to multiply. The compiler deduces them, along with the corresponding values of ROWS, COLS, and INNER, when the function is called.

int main() {
    Matrix<3, 5> res;

    Matrix<3, 4> x1;
    Matrix<4, 5> x2;
    res = multiply(x1, x2);

    Symmetric_Matrix<3, 4> sx1;
    Symmetric_Matrix<4, 5> sx2;
    res = multiply(sx1, sx2);

    res = multiply(x1, sx2);
    res = multiply(sx1, x2);

    return 0;
}

Of course, real code would provide implementations for get_value and set_value, but the multiply code doesn't care how those are implemented, so having two different types will work just fine.

  • Related