Is it possible to create a pytree (for JAX) based upon an xarray.DataArray?
The key issue seems to be that the data attribute of an xarray.DataArray is constructed as a numpy.ndarray view of the input array (such as a DeviceArray):
import jax.numpy as jnp
import xarray as xr
da = xr.DataArray(
    data=jnp.array([2.0, 3.0]),
    dims=("var",),  # trailing comma makes this a one-element tuple
    coords={"var": ["A", "B"]},
)
>>> type(da.data)
numpy.ndarray
Flattening/unflattening the DataArray into a pytree is relatively straightforward (all the attributes are aux data except data), but I don't know how to retrieve the DeviceArray, or even assign to it (I can't use at[.].set(.) here).
The alternative approach of constructing a container class (with a DeviceArray member) requires that all the relevant functionality of xarray be manually implemented. For a single feature, such as labelled indexing, this is possible but redundant.
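For reference, the container-class alternative described above might be sketched as follows. This is only an illustration under stated assumptions: LabelledVector, its sel method, and the scale function are hypothetical names, and the class covers a single 1-D labelled array rather than anything xarray actually provides.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class LabelledVector:
    """Hypothetical 1-D container: a JAX array plus string labels."""

    def __init__(self, data, labels):
        self.data = data             # pytree leaf, traced by JAX
        self.labels = tuple(labels)  # static, hashable metadata

    def sel(self, label):
        # The label lookup runs in Python, so the index is static under jit.
        return self.data[self.labels.index(label)]

    def tree_flatten(self):
        # Children (leaves) first, auxiliary data second.
        return (self.data,), self.labels

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


vec = LabelledVector(jnp.array([2.0, 3.0]), ["A", "B"])


@jax.jit
def scale(v, factor):
    # Functional style: build a new container instead of mutating v.
    return LabelledVector(v.data * factor, v.labels)


scaled = scale(vec, 10.0)
b_value = float(scaled.sel("B"))

# Gradients come back in the same container, one entry per label.
grad_a = jax.grad(lambda v: v.sel("A") ** 2)(vec)
```

Here the array stays a JAX array throughout, at the cost of re-implementing each xarray feature (only label selection in this sketch) by hand.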
Answer:
You may be able to do what you want by registering the xarray type as a custom PyTree node following the examples in the documentation.
For example, it might look like this:
import jax.numpy as jnp
from jax import tree_util
import xarray as xr
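tree_util.register_pytree_node(
    xr.DataArray,
    # flatten: return (children, aux_data); only the array data is a leaf
    lambda x: ((x.data,), {"dims": x.dims, "coords": x.coords}),
    # unflatten: receives (aux_data, children) and rebuilds the DataArray
    lambda kwds, args: xr.DataArray(*args, **kwds)
)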
da = xr.DataArray(
    data=jnp.array([2.0, 3.0]),
    dims=("var",),  # trailing comma makes this a one-element tuple
    coords={"var": ["A", "B"]},
)
print(da)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
data, tree = tree_util.tree_flatten(da)
print(data)
# [array([2., 3.], dtype=float32)]
da_reconstructed = tree_util.tree_unflatten(tree, data)
print(da_reconstructed)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
I don't think this is likely to work as intended in any but the simplest cases: JAX transformations are restricted to functional, side-effect-free code, and JAX arrays are immutable. xarray's operations in general violate both of these constraints.
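To make the immutability constraint concrete, here is a minimal sketch using only jax.numpy. The variable names are illustrative; the point is that in-place assignment is rejected and at[...].set(...) returns a new array instead.

```python
import jax.numpy as jnp

x = jnp.array([2.0, 3.0])

# In-place assignment is not supported on JAX arrays.
try:
    x[0] = 5.0
    in_place_failed = False
except TypeError:
    in_place_failed = True

# The functional update returns a new array and leaves x untouched.
y = x.at[0].set(5.0)
```

Any library code that writes into its underlying buffer in place would hit the same error when that buffer is a JAX array.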
Another issue: you'd have to be careful to define your flatten and unflatten functions so that all data and metadata are properly serialized. And if you're hoping to use this within any JAX function, be aware that you might run into issues with xarray's validation of inputs; see "Custom PyTrees and Initialization" in the JAX documentation for more information.
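One way to sidestep constructor validation, sketched below, is to have the unflatten function bypass __init__ entirely. CheckedVector is a hypothetical stand-in for a validating container; this is not how xarray itself is built, and doing the same for xr.DataArray would require knowledge of its internals.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class CheckedVector:
    """Hypothetical container whose constructor validates its input."""

    def __init__(self, data, labels):
        data = jnp.asarray(data)
        if data.ndim != 1 or data.shape[0] != len(labels):
            raise ValueError("data must be 1-D with one entry per label")
        self.data = data
        self.labels = tuple(labels)

    def tree_flatten(self):
        return (self.data,), self.labels

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Skip __init__: the children may be tracers or other
        # non-array objects that would fail the checks above.
        obj = object.__new__(cls)
        obj.data = children[0]
        obj.labels = aux_data
        return obj


vec = CheckedVector([2.0, 3.0], ["A", "B"])

# Ordinary array-valued map: leaves stay arrays.
doubled = tree_util.tree_map(lambda x: 2 * x, vec)

# A map whose leaves are not arrays still round-trips, because
# unflattening never runs the validating constructor.
summary = tree_util.tree_map(lambda x: f"shape={x.shape}", vec)
```

The trade-off is that objects rebuilt this way are not validated, so the checks in __init__ only protect instances created by user code.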