I have a dataclass like:
```python
import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray
```
In my code I create multiple instances of it. Is there an easy way to merge two of them attribute by attribute? For example, if I have:
```python
a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
```
I'd like to merge them into a single `Metric` containing `score1=jnp.array([10,10,10,10,10,10])`, `score2=jnp.array([20,20,20,20,20,20])` and `score3=jnp.array([30,30,30,30,30,30])`.
The easiest thing is probably just to define a method:
```python
import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def concatenate(self, other):
        # Join each score array with the matching array from `other`.
        return Metric(
            jnp.concatenate((self.score1, other.score1)),
            jnp.concatenate((self.score2, other.score2)),
            jnp.concatenate((self.score3, other.score3)),
        )
```
and then just call `a.concatenate(b)`. You could instead name the method `__add__`, which would let you simply write `a + b`. This is neater, but it could be confused with element-wise addition.
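The `concatenate` method above lists each field by hand, so it has to be edited whenever a score is added. A minimal sketch of a generic variant, assuming every field holds a 1-D array, loops over `dataclasses.fields` instead and also aliases the method as `__add__` so both spellings work:

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def concatenate(self, other: "Metric") -> "Metric":
        # Build the merged instance field by field, so new score fields
        # are picked up automatically without editing this method.
        merged = {
            f.name: jnp.concatenate((getattr(self, f.name), getattr(other, f.name)))
            for f in dataclasses.fields(self)
        }
        return type(self)(**merged)

    # Optional alias so `a + b` also works (this concatenates, it does NOT add).
    __add__ = concatenate


a = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))
b = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))
merged = a + b  # same result as a.concatenate(b)
```

Using `type(self)` rather than the hard-coded `Metric` also keeps subclasses intact when they are merged.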
It is possible to do this in a "JAX-centric" manner by registering the class `Metric` as a pytree node. google/flax, a neural network library built on top of JAX, provides the `flax.struct.dataclass` helper to do so. Once the class is registered, you can use the `jax.tree_util` package to manipulate `Metric` instances:
```python
import jax.numpy as jnp
from flax.struct import dataclass as flax_dataclass
# tree_multimap has been removed from recent JAX releases; tree_map accepts
# several trees with the same structure and does the same job.
from jax.tree_util import tree_map


@flax_dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray


a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))

# Concatenate each pair of matching leaves (score1 with score1, and so on).
tree_map(lambda x, y: jnp.concatenate([x, y]), a, b)
```
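Gives (output from an older JAX release; current versions print `Array(...)` instead of `DeviceArray(...)`):

```
Metric(score1=DeviceArray([10, 10, 10, 10, 10, 10], dtype=int32), score2=DeviceArray([20, 20, 20, 20, 20, 20], dtype=int32), score3=DeviceArray([30, 30, 30, 30, 30, 30], dtype=int32))
```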
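If you would rather not depend on Flax, a minimal sketch is to register a plain dataclass with JAX directly through `jax.tree_util.register_pytree_node_class`, which expects the class to provide `tree_flatten` and `tree_unflatten`. Because `tree_map` takes any number of structurally identical trees, the same call can merge more than two instances; `merge_metrics` below is a hypothetical helper name, not part of any library:

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def tree_flatten(self):
        # Children are the arrays JAX maps over, in field order; no static aux data.
        children = tuple(getattr(self, f.name) for f in dataclasses.fields(self))
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild an instance from the (possibly transformed) children.
        return cls(*children)


def merge_metrics(*metrics):
    # Hypothetical helper: tree_map calls the lambda with the matching leaf
    # from every Metric, so each field is concatenated across all inputs.
    return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs), *metrics)


a = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))
b = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))
merged = merge_metrics(a, b, b)  # three instances in one call
```

Registering the class this way also means `Metric` instances can be passed straight into `jax.jit`-compiled functions, just like the Flax version.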