multi-dimensional array type as template argument

Time:07-27

The Kokkos scientific computing library implements its multi-dimensional array (which it calls View) so that it knows which dimensions are fixed at compile time and which are specified at run time. For example, to create a 3-dimensional array data of shape (N0, N1, N2), you can write any of the following:

View<double***> data(N0, N1, N2); // 3 run-time, 0 compile-time dimensions
View<double**[N2]> data(N0, N1);  // 2 run-time, 1 compile-time
View<double*[N1][N2]> data(N0);   // 1 run-time, 2 compile-time
View<double[N0][N1][N2]> data;    // 0 run-time, 3 compile-time

If I want to implement something like View, how should I handle the template arguments? That is, how do I count the asterisks and square brackets in the template argument (and get the numbers inside the square brackets)?

Answer:

This question can be split into two parts.

  • Get the dimensions from the template parameter T.
  • Make the constructor accept a varying number of values as the run-time dimensions.

Let's look at them one by one.


For the first part, represent the dimensions with a list of non-type template parameters (a non-type typelist):

template <size_t... Ns>
struct Dimension {
    template <size_t N>
    using prepend = Dimension<N, Ns...>;
};
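The specializations and the View constructor in this answer rely on two helpers, join_t and get_v, whose definitions are not shown. A minimal sketch of both, assuming join_t<A, B> concatenates two Dimension lists and get_v<D, I> reads the I-th entry of D; the helper struct names join and dimension_at are hypothetical:

```cpp
#include <cstddef>
#include <type_traits>

// Dimension as in the answer: a compile-time list of extents (0 = run-time).
template <std::size_t... Ns>
struct Dimension {
    template <std::size_t N>
    using prepend = Dimension<N, Ns...>;
};

// join_t<A, B>: concatenation of two Dimension lists (hypothetical definition).
template <typename A, typename B>
struct join;

template <std::size_t... As, std::size_t... Bs>
struct join<Dimension<As...>, Dimension<Bs...>> {
    using type = Dimension<As..., Bs...>;
};

template <typename A, typename B>
using join_t = typename join<A, B>::type;

// get_v<D, I>: the I-th entry of D (hypothetical definition).
// Indexing past the end hits the undefined primary template and fails to compile.
template <typename D, std::size_t I>
struct dimension_at;

// Index 0: the first entry of the list.
template <std::size_t N, std::size_t... Ns>
struct dimension_at<Dimension<N, Ns...>, 0> : std::integral_constant<std::size_t, N> {};

// Index I > 0: drop the first entry and look up index I - 1 in the rest.
template <std::size_t I, std::size_t N, std::size_t... Ns>
struct dimension_at<Dimension<N, Ns...>, I> : dimension_at<Dimension<Ns...>, I - 1> {};

template <typename D, std::size_t I>
inline constexpr std::size_t get_v = dimension_at<D, I>::value;
```

The index-0 specialization is more specialized than the general one, so it ends the recursion. This needs C++17 for the inline variable template.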

A 0 marks a dimension whose extent is determined at run time. Next, declare a class template that will decompose its template parameter into these dimensions:

template <typename>
struct Analysis;

// Analysis<int**> -> Dimension<0, 0>
// Analysis<int[1][2]> -> Dimension<1, 2>
// Analysis<int*[3]> -> Dimension<0, 3>

Partial specializations of Analysis, each using the nested prepend alias template, peel off the pointers and [N] extents one layer at a time, recursively. Compile-time and run-time dimensions are collected in separate lists and then joined, run-time first: each * prepends a 0 to the dynamic list, and each [N] prepends N to the static list. The result follows the order in which the * and [N] appear in the type.

// Base case: a type with no * or [N] (e.g. int) contributes no dimensions.
template <typename T>
struct Analysis {
    using sdim = Dimension<>; // static
    using ddim = Dimension<>; // dynamic
    using dim = Dimension<>;  // all
};

// Peel one [N]: prepend N to the static list of the element type T.
template <typename T, size_t N>
struct Analysis<T[N]> {
    using nested = Analysis<T>;
    using sdim = typename nested::sdim::template prepend<N>;
    using ddim = typename nested::ddim;
    using dim = join_t<ddim, sdim>; // join_t concatenates two Dimension lists (not shown)
};

// Peel one *: prepend 0 to the dynamic list of the pointee type T.
template <typename T>
struct Analysis<T*> {
    using nested = Analysis<T>;
    using sdim = typename nested::sdim;
    using ddim = typename nested::ddim::template prepend<0>;
    using dim = join_t<ddim, sdim>;
};

T[] (an array of unknown bound) can be handled like T* and is not shown here. With these definitions, the following assertions hold:

static_assert(std::is_same_v<
            Analysis<int[1][2][3]>::dim,
            Dimension<1, 2, 3>>);
static_assert(std::is_same_v<
            Analysis<int***>::dim,
            Dimension<0, 0, 0>>);
static_assert(std::is_same_v<
            Analysis<int*[1][2]>::dim,
            Dimension<0, 1, 2>>);
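If you only need the counts and the extents rather than a list type, the standard <type_traits> facilities std::rank_v, std::extent_v and std::remove_all_extents_t already handle the square brackets; only the asterisks need a small recursive trait. A sketch of that alternative, where pointer_info, dynamic_rank_v and static_rank_v are hypothetical names:

```cpp
#include <cstddef>
#include <type_traits>

// Peels pointers one at a time: depth counts the '*', value is the element type.
// (Hypothetical trait; cv-qualified pointers such as double* const* are not handled.)
template <typename T>
struct pointer_info {
    using value = T;
    static constexpr std::size_t depth = 0;
};

template <typename T>
struct pointer_info<T*> {
    using value = typename pointer_info<T>::value;
    static constexpr std::size_t depth = 1 + pointer_info<T>::depth;
};

// Run-time dimensions: the '*' that remain once every [N] is stripped.
// For double**[7], std::remove_all_extents_t gives double**.
template <typename T>
inline constexpr std::size_t dynamic_rank_v =
    pointer_info<std::remove_all_extents_t<T>>::depth;

// Compile-time dimensions: the number of [N] brackets.
// The value inside the i-th bracket is std::extent_v<T, i>.
template <typename T>
inline constexpr std::size_t static_rank_v = std::rank_v<T>;
```

Note that std::rank_v also counts an unknown bound, so double*[] reports one bracket with std::extent_v equal to 0; treat that case separately if [] should mean a run-time dimension.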



With the dimensions in hand, building a View-like class is straightforward. Its constructor takes several parameters with default values, which serve as the run-time extents:

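template <typename T>
struct View {
    using dim = typename Analysis<T>::dim;
    // Sentinel for "extent not supplied"; avoids comparing size_t with -1.
    static constexpr size_t unset = static_cast<size_t>(-1);

    // get_v<D, I> is the I-th entry of D (helper not shown).
    // Sketch for rank-3 types; add parameters for higher ranks.
    View(size_t dim0 = unset, size_t dim1 = unset, size_t dim2 = unset) {
        if (get_v<dim, 0> == 0 && dim0 != unset) {
            // dimension 0 is a run-time extent: store dim0
        }
        if (get_v<dim, 1> == 0 && dim1 != unset) {}
        if (get_v<dim, 2> == 0 && dim2 != unset) {}
    }
};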
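A variadic constructor can replace the defaulted parameters and reject a wrong number of extents at compile time. The sketch below swaps in std::rank_v and std::extent_v from <type_traits> for the bracket analysis and uses a small recursive trait for the asterisks. View, pointer_info, extent() and size() are hypothetical names, not the Kokkos API, and storing the elements in a std::vector is an assumption made to keep it short:

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Counts the '*' in a type and recovers the element type underneath (hypothetical trait).
template <typename T>
struct pointer_info {
    using value = T;
    static constexpr std::size_t depth = 0;
};

template <typename T>
struct pointer_info<T*> {
    using value = typename pointer_info<T>::value;
    static constexpr std::size_t depth = 1 + pointer_info<T>::depth;
};

// Hypothetical View-like container (not Kokkos): the run-time extents come first,
// followed by the compile-time extents read from the [N] brackets.
template <typename T>
class View {
    using stripped = std::remove_all_extents_t<T>;  // e.g. double** for double**[4]

public:
    using value_type = typename pointer_info<stripped>::value;
    static constexpr std::size_t rank_dynamic = pointer_info<stripped>::depth;
    static constexpr std::size_t rank_static = std::rank_v<T>;
    static constexpr std::size_t rank = rank_dynamic + rank_static;

    // Viable only with exactly one integral argument per '*'.
    template <typename... Ns,
              typename = std::enable_if_t<sizeof...(Ns) == rank_dynamic &&
                                          (std::is_integral_v<Ns> && ...)>>
    explicit View(Ns... ns) {
        // The trailing 0 keeps the array non-empty when rank_dynamic == 0.
        const std::size_t runtime[] = {static_cast<std::size_t>(ns)..., 0};
        for (std::size_t i = 0; i < rank_dynamic; ++i) extents_[i] = runtime[i];
        fill_static(std::make_index_sequence<rank_static>{});

        // Allocate storage for the product of all extents.
        std::size_t total = 1;
        for (std::size_t e : extents_) total *= e;
        data_.resize(total);
    }

    std::size_t extent(std::size_t i) const { return extents_[i]; }
    std::size_t size() const { return data_.size(); }

private:
    // Writes std::extent_v<T, 0>, std::extent_v<T, 1>, ... after the run-time extents.
    template <std::size_t... Is>
    void fill_static(std::index_sequence<Is...>) {
        ((extents_[rank_dynamic + Is] = std::extent_v<T, Is>), ...);
    }

    std::array<std::size_t, rank> extents_{};
    std::vector<value_type> data_;
};
```

With this design, View<double**[4]> b(2, 3); has extents (2, 3, 4), while View<double**[4]> b(2); fails to compile because the constructor template is removed from overload resolution. Arrays of unknown bound such as double*[] are not handled here.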

