Merge dataclasses in python

Time:02-14

I have a dataclass like:

import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

In my code, I create multiple instances of it. Is there an easy way to merge two of them attribute by attribute? For example, if I have:

a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))

I'd like to merge them into a single Metric containing:

score1=jnp.array([10,10,10,10,10,10]), score2=jnp.array([20,20,20,20,20,20]) and score3=jnp.array([30,30,30,30,30,30])

CodePudding user response:

The easiest thing is probably just to define a method:

import dataclasses
import jax.numpy as jnp


@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def concatenate(self, other):
        return Metric(
            jnp.concatenate((self.score1, other.score1)),
            jnp.concatenate((self.score2, other.score2)),
            jnp.concatenate((self.score3, other.score3)),
        )

and then call a.concatenate(b). You could instead name the method __add__, which would let you write a + b. This is neater, but it could be confused with element-wise addition.
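If the dataclass gains more fields, listing each one by hand gets tedious. Here is a minimal sketch of a generic field-by-field merge built on dataclasses.fields. The helper name merge_fields and its combine parameter are hypothetical, not part of any library, and the sketch assumes every field is accepted by the constructor (no init=False fields). Plain lists stand in for arrays so the example needs only the standard library; with jnp arrays you would pass something like lambda x, y: jnp.concatenate((x, y)) as combine.

```python
import dataclasses


def merge_fields(a, b, combine):
    """Combine two instances of the same dataclass field by field.

    `merge_fields` and `combine` are illustrative names (assumptions),
    not an existing API.
    """
    if type(a) is not type(b):
        raise TypeError("both arguments must be instances of the same dataclass")
    merged = {
        f.name: combine(getattr(a, f.name), getattr(b, f.name))
        for f in dataclasses.fields(a)
    }
    # Assumes every field is a constructor argument (no init=False fields).
    return type(a)(**merged)


@dataclasses.dataclass
class Metric:
    score1: list
    score2: list
    score3: list


a = Metric([10, 10, 10], [20, 20, 20], [30, 30, 30])
b = Metric([10, 10, 10], [20, 20, 20], [30, 30, 30])

# For lists, + concatenates; for jnp arrays use jnp.concatenate instead,
# because + on arrays is element-wise addition.
merged = merge_fields(a, b, lambda x, y: x + y)
```

Passing the combiner explicitly keeps the concatenate-versus-add choice visible at the call site, which avoids the ambiguity mentioned above for __add__.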

CodePudding user response:

It is possible to do so in a "jax-centric" manner by registering the class Metric as a pytree node. Flax (google/flax), a neural network library built on top of jax, provides the flax.struct.dataclass helper to do so. Once the class is registered, you can use the jax.tree_util package to manipulate Metric instances:

from flax.struct import dataclass as flax_dataclass
# tree_multimap has been removed from recent JAX releases;
# tree_map accepts several trees and replaces it.
from jax.tree_util import tree_map
import jax.numpy as jnp

@flax_dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))

# Apply the function to matching leaves of a and b (score1 with score1, and so on).
tree_map(lambda x, y: jnp.concatenate([x, y]), a, b)

Gives:

Metric(score1=DeviceArray([10, 10, 10, 10, 10, 10], dtype=int32), score2=DeviceArray([20, 20, 20, 20, 20, 20], dtype=int32), score3=DeviceArray([30, 30, 30, 30, 30, 30], dtype=int32))
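If you would rather not depend on Flax, the registration can also be done by hand. Below is a minimal sketch, assuming a standard dataclasses.dataclass and the jax.tree_util.register_pytree_node_class decorator, which expects the class to define tree_flatten and tree_unflatten. Treating all three fields as children with no auxiliary data is a choice made for this example.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def tree_flatten(self):
        # Children are the array leaves; no static auxiliary data is needed here.
        children = (self.score1, self.score2, self.score3)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the instance from the leaves in the order tree_flatten produced.
        return cls(*children)


a = Metric(jnp.array([10, 10, 10]), jnp.array([20, 20, 20]), jnp.array([30, 30, 30]))
b = Metric(jnp.array([10, 10, 10]), jnp.array([20, 20, 20]), jnp.array([30, 30, 30]))

# tree_map pairs up matching leaves of a and b and concatenates each pair.
merged = jax.tree_util.tree_map(lambda x, y: jnp.concatenate([x, y]), a, b)
```

The result is again a Metric, with each field holding the six-element concatenation of the corresponding fields of a and b.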