For example, in 3 dimensions I would normally write something like:
vector<vector<vector<T>>> v(x, vector<vector<T>>(y, vector<T>(z, val)));
However, this gets tedious for complex types and higher dimensions. Is it possible to define a type, say `tensor`, whose usage would look like this:
tensor<T> t(x, y, z, val1);
t[i][j][k] = val2;
Answer:
It's possible with template metaprogramming.
Define a recursive template `NVector`:
// NVector<D, T>: a D-dimensional nested std::vector.
//
// NVector<D, T> derives from std::vector<NVector<D - 1, T>>; the recursion
// ends at NVector<1, T>, which derives from std::vector<T>. Element access is
// therefore ordinary chained indexing: v[i][j][k].
//
// Constructor: NVector<D, T>(n0, n1, ..., n{D-1}, val)
//   n0..n{D-1} - extent of each dimension, outermost first (each optional, default 0)
//   val        - value of every innermost element (optional, default T())
// Each level builds one (D-1)-dimensional prototype from the remaining
// arguments and copies it n times, so all sub-vectors are independent copies.
template<int D, typename T>
struct NVector : public std::vector<NVector<D - 1, T>> {
    // Without this guard, D <= 0 only fails through unbounded recursive
    // instantiation (NVector<-1>, NVector<-2>, ...) and a cryptic error.
    static_assert(D >= 1, "NVector requires at least one dimension");

    template<typename... Args>
    NVector(int n = 0, Args... args)
        : std::vector<NVector<D - 1, T>>(n, NVector<D - 1, T>(args...)) {
    }
};

// Base case: a one-dimensional NVector is a std::vector<T> holding n copies of val.
template<typename T>
struct NVector<1, T> : public std::vector<T> {
    NVector(int n = 0, const T &val = T()) : std::vector<T>(n, val) {
    }
};
You can use it like this:
// Example: a 5 x 5 x 5 grid of ints with every element set to 0; prints 0.
// The unqualified `cout` assumes <iostream> and `using namespace std;` are in
// scope (and that these lines sit inside a function body such as main()).
const int n = 5, m = 5, k = 5;
NVector<3, int> a(n, m, k, 0);
cout << a[0][0][0] << '\n';
I think it's clear how it can be used, but to spell it out: `NVector<number of dimensions, element type> a(extent of each dimension, separated by commas (optional)..., default value (optional));`
Answer:
The other answer shows a good way of making a vector of vectors with template metaprogramming. If you want a multidimensional array data structure with fewer allocations and contiguous storage underneath, here is an example of how to achieve that with an `NDArray` template class that wraps access to an underlying `std::vector`. It could be extended to define `operator=`, copy operations, per-dimension debug bounds checking, etc. for extra convenience.
NDArray.h
#pragma once
#include <algorithm>
#include <array>
#include <vector>
// NDArray<N, ValueType>: a dense N-dimensional array backed by a single
// contiguous std::vector (one allocation, unlike nested vectors).
//
// Layout is column-major (the first index varies fastest): element
// (i0, i1, ..., i{N-1}) lives at i0*offsets[0] + i1*offsets[1] + ..., where
// offsets[0] == 1 and offsets[k] == offsets[k-1] * dims[k-1].
//
// No bounds checking is performed on indices or axes.
// NOTE(review): sizes and linear indices are computed in int, so arrays with
// more than INT_MAX elements would overflow.
template<int N, typename ValueType>
class NDArray {
    // compute_offsets writes offsets[0], which does not exist when N == 0.
    static_assert(N >= 1, "NDArray requires at least one dimension");

public:
    // Construct an array with the given extent along each of the N axes.
    // Every element starts as ValueType{}. Supplying a number of extents
    // other than N is a compile-time error.
    template<typename... Args>
    NDArray(Args... args)
        : dims({{static_cast<int>(args)...}}),
          // offsets and data read `dims`, which is declared (hence
          // initialised) before them.
          offsets(compute_offsets(dims)),
          data(compute_size(dims), ValueType{})
    {
        static_assert(sizeof...(args) == N,
                      "Incorrect number of NDArray dimension arguments");
    }

    // Set every element to val.
    void fill(ValueType val) {
        std::fill(data.begin(), data.end(), val);
    }

    // Read-only element access: returns a copy of element (args...).
    // Exactly N indices are required (checked at compile time).
    template<typename... Args>
    ValueType operator()(Args... args) const {
        static_assert(sizeof...(args) == N,
                      "Incorrect number of NDArray index arguments");
        return data[calc_index({{static_cast<int>(args)...}})];
    }

    // Mutable element access: returns a reference to element (args...).
    // Exactly N indices are required (checked at compile time).
    template<typename... Args>
    ValueType& operator()(Args... args) {
        static_assert(sizeof...(args) == N,
                      "Incorrect number of NDArray index arguments");
        return data[calc_index({{static_cast<int>(args)...}})];
    }

    // Extent of the array along `axis` (0-based; not range-checked).
    int length(int axis) const { return dims[axis]; }

    const int num_dims = N;

private:
    // Column-major strides: strides[0] = 1, strides[k] = strides[k-1] * extents[k-1].
    static std::array<int, N> compute_offsets(const std::array<int, N>& extents) {
        std::array<int, N> strides{};
        strides[0] = 1;
        for (int i = 1; i < N; ++i) {
            strides[i] = strides[i - 1] * extents[i - 1];
        }
        return strides;
    }

    // Total element count: the product of all extents.
    static int compute_size(const std::array<int, N>& extents) {
        int size = 1;
        for (int extent : extents) size *= extent;
        return size;
    }

    // Linear position of `indices` in `data`: the dot product with the strides.
    int calc_index(const std::array<int, N>& indices) const {
        int idx = 0;
        for (int i = 0; i < N; ++i) idx += offsets[i] * indices[i];
        return idx;
    }

    const std::array<int, N> dims;
    const std::array<int, N> offsets;
    std::vector<ValueType> data;
};
This overloads `operator()` for exactly `N` index arguments, so a call with the wrong number of arguments won't compile. Some example usage:
// Example usage of NDArray.
// NOTE(review): statements such as a.fill(...) and std::cout << ... are not
// valid at namespace scope; presumably these lines were excerpted from the
// body of main() — confirm.
using Array2D = NDArray<2,double>;
using Array3D = NDArray<3,double>;
// A 3 x 6 array of doubles: every element is set to 1.0, then a(2, 4) to 2.0.
auto a = Array2D(3, 6);
a.fill(1.0);
a(2, 4) = 2.0;
// a(2, 4, 4) will not compile: operator() static_asserts on the index count.
// `<< a` uses the operator<<(std::ostream&, const Array2D&) helper shown further down.
std::cout << "a = " << std::endl << a << std::endl;
// auto b = Array3D(4, 4); // will not compile: the constructor needs exactly 3 extents
auto b = Array3D(4, 3, 2);
b.fill(-1.0);
b(0, 0, 0) = 4.0;
b(1, 1, 1) = 2.0;
std::cout << "b = " << std::endl << b << std::endl;
(This uses the following helper output operators for 2D and 3D arrays.)
// Print a 2-D array row by row: one line per first-axis index i, with each
// element arr(i, j) followed by a single space. Returns `os` for chaining.
std::ostream& operator<<(std::ostream& os, const Array2D& arr) {
    for (int i = 0; i < arr.length(0); ++i) {
        for (int j = 0; j < arr.length(1); ++j) {
            os << arr(i, j) << " ";
        }
        // '\n' rather than std::endl: no need to flush after every row.
        os << '\n';
    }
    return os;
}
// Print a 3-D array as a sequence of 2-D slices, one per third-axis index k,
// each headed by "array(:,:,k) = ". Within a slice, every row i is indented by
// one space and each element arr(i, j, k) is followed by a single space; a
// blank line separates slices. Returns `os` for chaining.
std::ostream& operator<<(std::ostream& os, const Array3D& arr) {
    for (int k = 0; k < arr.length(2); ++k) {
        os << "array(:,:,"<<k<<") = " << '\n';
        for (int i = 0; i < arr.length(0); ++i) {
            os << " ";
            for (int j = 0; j < arr.length(1); ++j) {
                os << arr(i, j, k) << " ";
            }
            os << '\n';
        }
        // '\n' rather than std::endl: no need to flush after every line.
        os << '\n';
    }
    return os;
}