I have a coastal oceanic dataset with coordinates (time, depth, lat, lon). Data for deep layers is masked out with nan close to the coast. I want to make a new dataset with coordinates (time, lat, lon) by selecting values along the ocean floor.
Currently I am doing this using `dataset.bfill('depth').isel({'depth': 0})`, which backfills all the NaNs using the deepest data present, then slices off the deepest layer. This works, but it is inefficient: `bfill` updates every variable, across all timesteps, to fill the missing values downwards.
The depth of the ocean floor does not change between timesteps or across variables. I would like to use this fact to make this operation more efficient. Assume I have a (lat, lon) array which contains an index into the depth coordinate, indicating where the ocean floor is at that coordinate. Making this index array is relatively easy, but I don't know how to then use it to select the right data.
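For reference, here is one way such an index array could be built, sketched under the same assumptions as the example below: dimensions named `t`, `z`, `y`, `x`; level 0 along `z` is the deepest (depth runs 4.25 down to 0.25); and the NaN mask is identical for every timestep and variable. `floor_index` is a hypothetical helper, not the question's `compute_ocean_floor_index`. It only reads one `(z, y, x)` slice of one variable.

```python
import numpy as np
import xarray as xr


def floor_index(ds: xr.Dataset, var: str = "temp", zdim: str = "z", tdim: str = "t") -> xr.DataArray:
    """Hypothetical helper: z position of the deepest non-NaN value at each (y, x).

    Assumes index 0 along `zdim` is the deepest level and that the NaN mask
    is the same for every timestep and variable, so one timestep of one
    variable is enough to locate the floor.
    """
    # One timestep of one variable, as a boolean "wet" mask with z first.
    wet = ds[var].isel({tdim: 0}).notnull().transpose(zdim, ...)
    # argmax returns the first True along z, i.e. the deepest wet level.
    # All-NaN (land) columns give 0, which later selects NaN, as bfill would.
    idx = np.argmax(wet.values, axis=0)
    return xr.DataArray(idx, dims=wet.dims[1:])
```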
Is there a way of using this (lat, lon) array of depth indices to select just the depth index that I am interested in, efficiently, across all variables and timesteps? i.e. something like:
```python
>>> dataset
Dimensions:  (t: 5, z: 5, y: 5, x: 5)
Coordinates:
    time     (t) datetime64 2022-02-08 ...
    lon      (x) int64 0 -1 -2 -3 -4
    lat      (y) int64 0 1 2 3 4
    depth    (z) float64 4.25 3.25 2.25 1.25 0.25
Dimensions without coordinates: z, y, x, t
Data variables:
    temp     (t, z, y, x) float64 0.0 nan nan nan nan nan ... 4.0 4.0 4.0 4.0 4.0
>>> depth_indices = compute_ocean_floor_index(
...     dataset, depth_variable='depth', coordinate_variables=['lon', 'lat'])
>>> depth_indices
array([[0, 1, 2, 3, 4],
       [1, 1, 2, 3, 4],
       [2, 2, 2, 3, 4],
       [3, 3, 3, 3, 4],
       [4, 4, 4, 4, 4]])
>>> dataset_floor = dataset.some_selector(depth_indices)
>>> dataset_floor
Dimensions:  (t: 5, y: 5, x: 5)
Coordinates:
    time     (t) datetime64 2022-02-08 ...
    lon      (x) int64 0 -1 -2 -3 -4
    lat      (y) int64 0 1 2 3 4
Data variables:
    temp     (t, y, x) float64 0.0 1.0 2.0 3.0 4.0 1.0 ... 4.0 4.0 4.0 4.0 4.0
```
The current implementation passes the following test function. The new implementation I am after should pass the same test without using `bfill()`:
```python
import numpy as np
import pandas as pd
import xarray as xr
from numpy.testing import assert_equal

from cemarray.operations import ocean_floor


def test_ocean_floor():
    # Values will be a 3D cube of values, with a slice along the x-axis like
    #     y
    #   44444
    #   3333.
    # d 222..
    #   11...
    #   0....
    values = np.full((5, 5, 5, 5), fill_value=np.nan)
    for i in range(5):
        values[:, i, :i + 1, :i + 1] = i
    temp = xr.DataArray(
        data=values,
        dims=['t', 'z', 'y', 'x'],
    )
    dataset = xr.Dataset(
        data_vars={"temp": temp},
        coords={
            'time': (['t'], pd.date_range('2022-02-08', periods=5)),
            'lon': (['x'], -np.arange(5)),
            'lat': (['y'], np.arange(5)),
            'depth': (['z'], 4.25 - np.arange(5), {'positive': 'down'}),
        }
    )

    floor_dataset = ocean_floor(dataset, ['depth'])

    assert floor_dataset.dims == {
        't': 5,
        'x': 5,
        'y': 5,
    }
    assert set(floor_dataset.coords.keys()) == {'time', 'lon', 'lat'}

    # We should see values for the deepest layer that has a value there
    expected_values = [
        [0, 1, 2, 3, 4],
        [1, 1, 2, 3, 4],
        [2, 2, 2, 3, 4],
        [3, 3, 3, 3, 4],
        [4, 4, 4, 4, 4],
    ]
    assert_equal(
        floor_dataset['temp'].values,
        np.array([expected_values] * 5)
    )
```
Answer:
Yep! xarray's Advanced Indexing works in many dimensions too!
Create the indexer "some_selector" you thought up (which totally does exist!) by indexing with a DataArray whose values are the positions you'd like to select, and whose dimensions (and coordinates, if any) match the target result. Because this uses `isel`, the values are integer positions along `z`, not depth labels. In this case, you want a DataArray of `z` positions indexed by `y` and `x`:
```python
>>> depth_indices = compute_ocean_floor_index(
...     dataset, depth_variable='depth', coordinate_variables=['lon', 'lat'])
>>> depth_indices
array([[0, 1, 2, 3, 4],
       [1, 1, 2, 3, 4],
       [2, 2, 2, 3, 4],
       [3, 3, 3, 3, 4],
       [4, 4, 4, 4, 4]])
>>> selector = xr.DataArray(depth_indices, dims=('y', 'x'))
```
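For intuition, this kind of selector does the same job as NumPy's `np.take_along_axis` on the raw `(t, z, y, x)` array: one `z` position per `(y, x)` point, broadcast over `t`. The sketch below uses a made-up random array (not the question's data) and the same selection through xarray's `isel`, so the two routes can be compared.

```python
import numpy as np
import xarray as xr

# Made-up (t, z, y, x) array; only the indexing pattern matters here.
rng = np.random.default_rng(0)
values = rng.random((3, 4, 2, 2))

# One z position per (y, x) point, shared by every timestep.
depth_indices = np.array([[0, 1],
                          [2, 3]])

# Plain NumPy: add singleton t and z axes so the (y, x) indices broadcast
# over t, take one z level per point, then drop the length-1 z axis.
# Afterwards floor_np[t, y, x] == values[t, depth_indices[y, x], y, x].
floor_np = np.take_along_axis(values, depth_indices[None, None, :, :], axis=1)[:, 0]

# xarray: the same selection via vectorized indexing with a (y, x) DataArray.
da = xr.DataArray(values, dims=("t", "z", "y", "x"))
floor_xr = da.isel(z=xr.DataArray(depth_indices, dims=("y", "x")))
```

The xarray version carries dimension names along, so there is no need to pad the index array with singleton axes by hand.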
This selector can now be used to pull out the `z` level for each `(y, x)` pair, for every timestep (`t` is not indexed, so it passes through unchanged):
```python
>>> dataset_floor = dataset.isel(z=selector)
>>> dataset_floor
Dimensions:  (t: 5, y: 5, x: 5)
Coordinates:
    time     (t) datetime64 2022-02-08 ...
    lon      (x) int64 0 -1 -2 -3 -4
    lat      (y) int64 0 1 2 3 4
Data variables:
    temp     (t, y, x) float64 0.0 1.0 2.0 3.0 4.0 1.0 ... 4.0 4.0 4.0 4.0 4.0
```
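Putting the pieces together, here is a sketch of an `isel`-based replacement aimed at the question's test assertions. `ocean_floor_isel` and its parameters are hypothetical (it is not the `cemarray` `ocean_floor`), and it assumes depth is positive down, the time dimension is called `t`, and the NaN mask is shared by every variable and timestep. As far as I understand xarray's vectorized indexing, `isel` also indexes the 1-D `depth` coordinate, leaving a 2-D `(y, x)` coordinate holding the floor depth, so the sketch drops it to match the test's expected coordinates `{'time', 'lon', 'lat'}`.

```python
from typing import Optional

import numpy as np
import xarray as xr


def ocean_floor_isel(
    dataset: xr.Dataset,
    depth_coord: str = "depth",
    time_dim: str = "t",
    mask_var: Optional[str] = None,
) -> xr.Dataset:
    """Select the deepest non-NaN level at every horizontal point (sketch).

    Hypothetical helper. Assumes `depth_coord` is a 1-D, positive-down
    coordinate along the dimension to collapse, that `time_dim` exists on
    the mask variable, and that the NaN mask is shared by all variables
    and timesteps.
    """
    (zdim,) = dataset[depth_coord].dims
    if mask_var is None:
        # Any data variable with the depth dimension will do.
        mask_var = next(n for n, v in dataset.data_vars.items() if zdim in v.dims)

    # Wet/dry mask from a single timestep, with the depth dimension first.
    wet = dataset[mask_var].isel({time_dim: 0}).notnull().transpose(zdim, ...)
    wet_np = wet.values
    nz = wet_np.shape[0]

    depth = dataset[depth_coord].values
    if depth[0] >= depth[-1]:
        # Deepest level first: the floor is the first wet level.
        idx = np.argmax(wet_np, axis=0)
    else:
        # Deepest level last: the floor is the last wet level.
        idx = nz - 1 - np.argmax(wet_np[::-1], axis=0)
    # All-dry (land) columns select an all-NaN level either way, like bfill.

    selector = xr.DataArray(idx, dims=wet.dims[1:])
    floor = dataset.isel({zdim: selector})
    # Drop the now 2-D floor-depth coordinate so only time/horizontal coords remain.
    return floor.drop_vars(depth_coord)
```

Only one timestep of one variable is read to locate the floor; every variable is then selected with a single vectorized `isel`, instead of being filled level by level.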