I've been playing around with a small bare-bones multi-dimensional tensor array implementation and have run into an issue with std::make_index_sequence
in combination with variadic template arguments. Given the following stripped-down implementation:
template <class scalar_t, std::size_t ... Dims>
class tensor {
public:
    using value_type = scalar_t;
    using size_type = std::size_t;
    using index_type = std::size_t;
    using container_type = std::array<value_type, (Dims * ...)>;
    using shape_type = std::array<size_type, sizeof...(Dims)>;
    using stride_type = std::array<index_type, sizeof...(Dims)>;
    constexpr static inline size_type size = (Dims * ...);
    constexpr static inline size_type rank = sizeof...(Dims);
    constexpr static inline shape_type shape = /* omitted for brevity */;
    constexpr static inline stride_type strides = /* omitted for brevity */;
    /* constructors omitted */
private:
    container_type m_data{};
    /* operator() and resolve_index, shown below, are members of this class */
};
I now want a variadic template on the call operator, operator()(), to access elements within the private container:
template <class ... Indices, std::enable_if_t<sizeof...(Indices) == rank, int> = 0>
[[nodiscard]] constexpr value_type &operator()(Indices && ... index) noexcept {
    const index_type data_index = resolve_index(std::make_index_sequence<rank>(), std::forward<Indices>(index)...);
    return m_data[data_index];
}
The idea is fairly straightforward: generate an index_sequence for the rank of the tensor, then forward the request to a private helper, resolve_index, that resolves the index based on the stride memory layout. Here is a working solution:
WORKING SOLUTION
template <class Indices>
[[nodiscard]] constexpr index_type resolve_index(const std::size_t axis, Indices && index) noexcept {
    return index * strides[axis];
}
template <std::size_t ... Axes, class ... Indices>
[[nodiscard]] constexpr index_type resolve_index(std::index_sequence<Axes...>, Indices && ... index) noexcept {
    // Fold over +: sum the per-axis offsets.
    return (resolve_index(Axes, std::forward<Indices>(index)) + ...);
}
resolve_index is overloaded. I am certain the compiler can expand the fold expression at compile time; however, each single invocation of resolve_index(std::size_t, Indices&&) executes at runtime (which is fine).
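To make the arithmetic of the two overloads concrete, here is a minimal standalone sketch. It uses free functions instead of members, and the hard-coded strides (row-major, for a hypothetical 2 x 3 x 4 tensor) and the flat_index wrapper are assumptions made only for this illustration.

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Assumption: row-major strides of a hypothetical 2 x 3 x 4 tensor.
inline constexpr std::array<std::size_t, 3> strides{12, 4, 1};

// Single-axis overload: the offset contributed by one index.
constexpr std::size_t resolve_index(std::size_t axis, std::size_t index) noexcept {
    return index * strides[axis];
}

// Pack overload: Axes... and index... are expanded in lockstep,
// and the per-axis offsets are summed by the fold over +.
template <std::size_t... Axes, class... Indices>
constexpr std::size_t resolve_index(std::index_sequence<Axes...>, Indices... index) noexcept {
    static_assert(sizeof...(Axes) == sizeof...(Indices), "one index per axis");
    return (resolve_index(Axes, static_cast<std::size_t>(index)) + ...);
}

// Hypothetical wrapper standing in for tensor::operator().
template <class... Indices>
constexpr std::size_t flat_index(Indices... index) noexcept {
    return resolve_index(std::make_index_sequence<sizeof...(Indices)>{}, index...);
}

// 1 * 12 + 2 * 4 + 3 * 1 == 23
static_assert(flat_index(1, 2, 3) == 23);
```

Because both packs have the same length, the expansion pairs axis 0 with the first index, axis 1 with the second, and so on.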
My pet peeve with this solution is that I normally prefer to use if constexpr (...) whenever possible to eliminate simple function overloads like the ones above, especially if the return type of the function doesn't have to be deduced via auto or decltype(auto).
Therefore I'd like to write something like the following:
template <class ... Axes, class ... Indices>
[[nodiscard]] constexpr index_type resolve_index(Axes && ... axis, Indices && ... index) noexcept {
    if constexpr (sizeof...(Indices) == 1)
        return (index + ...) * strides[(axis + ...)];
    else
        return (resolve_index(std::forward<Axes>(axis), std::forward<Indices>(index)) + ...);
}
Which unfortunately errors out at compile time with:
error: mismatched argument pack lengths while expanding ‘((tecra::tensor<scalar_t, Dims>*)this)->tecra::tensor<scalar_t, Dims>::resolve_index(forward<Axes>(axis), forward<Indices>(index))’
108 | return (resolve_index(std::forward<Axes>(axis), std::forward<Indices>(index)) + ...);
| ^
Where did I go wrong? Here is a working godbolt example: https://godbolt.org/z/qsY51n8f7 (feel free to ignore the internal
stuff). Thanks to anyone looking into this!
CodePudding user response:
template <class ... Axes, class ... Indices>
[[nodiscard]] constexpr index_type
resolve_index(Axes&& ... axis, Indices&& ... index) noexcept;
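has 2 issues:
1. Axes&&... is non-deducible because it is not the last function parameter. Axes is therefore deduced as an empty pack and every argument ends up in Indices, which is why the pack lengths do not match.
2. The axis argument is either a std::size_t or a std::index_sequence<Is...> (so the parameter should just be a single Axe&&), but you cannot expand over Is without a helper (function or lambda).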
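The first issue can be observed in isolation. The sketch below is a hypothetical helper (count_packs and pack_sizes are names invented for this example) that reports how many arguments each pack received. With two consecutive function parameter packs, compilers such as GCC and Clang leave the first one empty and hand every argument to the second.

```cpp
#include <cstddef>

// Hypothetical result type: the number of arguments deduced for each pack.
struct pack_sizes {
    std::size_t first;
    std::size_t second;
};

// First&&... is not the last function parameter, so it is a non-deduced
// context; without explicit template arguments it ends up empty.
template <class... First, class... Second>
constexpr pack_sizes count_packs(First&&..., Second&&...) noexcept {
    return {sizeof...(First), sizeof...(Second)};
}

static_assert(count_packs(0, 1, 2).first == 0);   // nothing lands in First
static_assert(count_packs(0, 1, 2).second == 3);  // everything lands in Second
```

Applied to the question, Axes is empty while Indices holds the index_sequence plus all indices, so expanding both packs in one pattern fails with mismatched lengths.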
It would be something like:
template <class Axe, class ... Indices>
[[nodiscard]] constexpr index_type resolve_index(Axe axe, Indices && ... index) noexcept {
    if constexpr (std::is_same_v<Axe, std::size_t>) {
        static_assert(sizeof...(Indices) == 1);
        return (index + ...) * strides[axe];
    } else {
        // Templated lambda (requires C++20) to recover Is... from the index_sequence.
        return [&]<std::size_t...Is>(std::index_sequence<Is...>){
            static_assert(sizeof...(Indices) == sizeof...(Is));
            return (resolve_index(Is, std::forward<Indices>(index)) + ...);
        }(axe);
    }
}
I suggest adding the sequence directly as a class template parameter:
template <class scalar_t, typename seq_dim, std::size_t ... Dims>
class tensor_impl;
template <class scalar_t, std::size_t... Is, std::size_t ... Dims>
class tensor_impl<scalar_t, std::index_sequence<Is...>, Dims...>
{
// You might directly use Is...
// simplifying your interface (you might get rid of some template)
// ...
};
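template <class scalar_t, std::size_t ... Dims>
// std::index_sequence_for takes types, so build the sequence from the number of dimensions instead.
using tensor = tensor_impl<scalar_t, std::make_index_sequence<sizeof...(Dims)>, Dims...>;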
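A rough C++17 sketch of how this layout could look once filled in, assuming row-major strides. The helper row_major_strides and the public static offset function are hypothetical names added for illustration and are not taken from the question. With Is... available in the class, the index can be computed with a single fold, without overloads or if constexpr.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical helper: row-major strides computed from the dimensions.
template <std::size_t... Dims>
constexpr std::array<std::size_t, sizeof...(Dims)> row_major_strides() {
    constexpr std::size_t n = sizeof...(Dims);
    std::array<std::size_t, n> dims{Dims...};
    std::array<std::size_t, n> result{};
    std::size_t acc = 1;
    for (std::size_t i = n; i-- > 0;) {  // walk from the last axis to the first
        result[i] = acc;
        acc *= dims[i];
    }
    return result;
}

template <class scalar_t, class seq_dim, std::size_t... Dims>
class tensor_impl;

template <class scalar_t, std::size_t... Is, std::size_t... Dims>
class tensor_impl<scalar_t, std::index_sequence<Is...>, Dims...> {
public:
    using index_type = std::size_t;
    static constexpr std::size_t rank = sizeof...(Dims);
    static constexpr std::size_t size = (Dims * ...);
    static constexpr std::array<index_type, sizeof...(Dims)> strides =
        row_major_strides<Dims...>();

    // Is... and index... have the same length, so they expand in lockstep.
    template <class... Indices,
              std::enable_if_t<sizeof...(Indices) == sizeof...(Dims), int> = 0>
    [[nodiscard]] static constexpr index_type offset(Indices... index) noexcept {
        return ((static_cast<index_type>(index) * strides[Is]) + ...);
    }

    template <class... Indices,
              std::enable_if_t<sizeof...(Indices) == sizeof...(Dims), int> = 0>
    [[nodiscard]] constexpr scalar_t& operator()(Indices... index) noexcept {
        return m_data[offset(index...)];
    }

private:
    std::array<scalar_t, size> m_data{};
};

template <class scalar_t, std::size_t... Dims>
using tensor = tensor_impl<scalar_t, std::make_index_sequence<sizeof...(Dims)>, Dims...>;

// Strides of a 2 x 3 x 4 tensor are {12, 4, 1}: 1 * 12 + 2 * 4 + 3 * 1 == 23.
static_assert(tensor<int, 2, 3, 4>::offset(1, 2, 3) == 23);
```

Usage stays the same as before: tensor<int, 2, 3, 4> t; t(1, 2, 3) = 7; writes to element 23 of the underlying array.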