Converting a function from 1D to ND - apply_along_axis

We have a function that works along the time axis, but only for 1D data. We’d like to convert it to work on a full dataset.

import xarray
import numpy as np
import scipy.integrate
import dask.array

1D function

Our initial function, which works only on one-dimensional data, is:

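def edof_1d(A):
    N = len(A)
    x = A - A.mean()  # anomalies about the mean
    c = np.correlate(x, x, 'full')
    c = c[N-1:]/(N-1-np.arange(0, N, 1))  # non-negative lags, normalised by N-1-lag
    n = 0
    # Find the first non-positive autocovariance, searching at most N/2 lags
    while (c[n] > 0) and (n < N/2):
        n = n+1
    # scipy.integrate.trapz was removed in SciPy 1.14; trapezoid is the same function
    T = scipy.integrate.trapezoid(c[:n])/c[0]
    edof = N/(2*T)
    return edof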
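
As a quick sanity check, the 1D function can be run on a synthetic series before touching real data. The sketch below is an illustration only: the AR(1) series, its coefficient `phi` and the `trapezoid` compatibility shim are assumptions of this example, and NumPy's trapezoid rule stands in for the SciPy one so the block needs only NumPy.

```python
import numpy as np

# Compatibility shim (an assumption of this sketch): NumPy 2.0 renamed
# np.trapz to np.trapezoid, so use whichever name is available.
trapezoid = getattr(np, "trapezoid", None) or np.trapz


def edof_1d(A):
    N = len(A)
    x = A - A.mean()
    c = np.correlate(x, x, "full")
    # The final lag divides by zero; silence that warning, since only
    # lags up to N/2 are ever read.
    with np.errstate(divide="ignore", invalid="ignore"):
        c = c[N - 1:] / (N - 1 - np.arange(0, N, 1))
    n = 0
    while (c[n] > 0) and (n < N / 2):
        n = n + 1
    T = trapezoid(c[:n]) / c[0]
    return N / (2 * T)


# Hypothetical test series: an AR(1) process with strong persistence
rng = np.random.default_rng(0)
N = 1000
phi = 0.9
noise = rng.standard_normal(N)
series = np.empty(N)
series[0] = noise[0]
for t in range(1, N):
    series[t] = phi * series[t - 1] + noise[t]

print(edof_1d(series))
```

A strongly autocorrelated series like this one should give a value well below its length of 1000, because neighbouring samples carry little independent information.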

ND function

In this case the function calls numpy.correlate, which needs the whole time axis at once. The best way to parallelise a function is to rewrite it to compute the whole grid at once, stepping over time. See the example Converting a function from 1D to ND - slice method.

A more general way to convert the 1D function to work on an ND dataset is to use dask.array.apply_along_axis (a dask-array version of numpy.apply_along_axis). This will process the chunks in parallel, but will be a bit slower than the rewrite method.

def edof(A):
    axis = 0  # 'time' is generally axis zero; for a DataArray use A.get_axis_num('time')
    edof_grid = dask.array.apply_along_axis(edof_1d, axis, A)
    return np.nanmean(edof_grid)
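
To see what dask.array.apply_along_axis does on a small scale, here is a minimal sketch comparing it with numpy.apply_along_axis. The function `lag1_autocorr` and the random array are hypothetical stand-ins, not part of the original example. The sketch also passes the optional `dtype` and `shape` arguments, which describe the output of one call to the 1D function. When they are omitted, dask calls the 1D function once on a length-one dummy array to infer them, which is a plausible source of the RuntimeWarning printed when edof() is run later in this example; I'm assuming a dask version recent enough to accept these arguments.

```python
import numpy as np
import dask.array


def lag1_autocorr(x):
    """Hypothetical 1D function: lag-1 autocorrelation of a series."""
    x = x - x.mean()
    return (x[:-1] * x[1:]).sum() / (x * x).sum()


# Small synthetic (time, lat, lon) array, chunked horizontally only
rng = np.random.default_rng(0)
data = rng.standard_normal((50, 6, 8))
lazy = dask.array.from_array(data, chunks=(50, 3, 4))

# dtype and shape describe the result of a single 1D call (a scalar here);
# supplying them skips dask's trial call on a dummy array
result = dask.array.apply_along_axis(
    lag1_autocorr, 0, lazy, dtype=data.dtype, shape=()
)

# The NumPy equivalent, computed eagerly on the in-memory array
expected = np.apply_along_axis(lag1_autocorr, 0, data)

print(result.shape)
```

The time axis is consumed by the 1D function, so the result keeps only the two horizontal dimensions.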

Sample data

To test performance, I’ll use ERA5 data of approximately the target size. Note the horizontal chunks: we’re working along the time axis, so each chunk should span the full time axis and be split only in latitude and longitude.

path = "/g/data/rt52/era5/single-levels/reanalysis/2t/2001/2t_era5_oper_sfc_200101*.nc"
ds = xarray.open_mfdataset(
    path, combine="nested", concat_dim="time", chunks={'latitude': 91, 'longitude': 180}
)
ds.t2m
<xarray.DataArray 't2m' (time: 744, latitude: 721, longitude: 1440)>
dask.array<open_dataset-ec1e63fd41cfe92c2eb0ba0648acc61at2m, shape=(744, 721, 1440), dtype=float32, chunksize=(744, 91, 180), chunktype=numpy.ndarray>
Coordinates:
  * longitude  (longitude) float32 -180.0 -179.8 -179.5 ... 179.2 179.5 179.8
  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * time       (time) datetime64[ns] 2001-01-01 ... 2001-01-31T23:00:00
Attributes:
    units:      K
    long_name:  2 metre temperature
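
The chunk layout in the repr above can be checked with a little arithmetic. This is a rough sketch using only the shape, chunk size and dtype printed above; it treats every chunk as full-sized, although the chunks on the latitude edge are slightly smaller.

```python
import math

# Values taken from the DataArray repr above
shape = (744, 721, 1440)      # (time, latitude, longitude)
chunksize = (744, 91, 180)
itemsize = 4                  # bytes per float32 value

# Memory used by one full-sized chunk
chunk_bytes = math.prod(chunksize) * itemsize

# Number of chunks: round up along each axis, then multiply
n_chunks = math.prod(math.ceil(s / c) for s, c in zip(shape, chunksize))

print(f"{chunk_bytes / 1e6:.1f} MB per chunk, {n_chunks} chunks")
```

Each chunk is about 48.7 MB and there are 64 of them, so a worker only ever needs a small part of the month in memory at once.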

Start a dask client

I’m running on Gadi, where climtas.nci.GadiClient() gets the available resources from PBS. Alternatively, you can start up a client manually with dask.distributed.Client().

import climtas.nci
climtas.nci.GadiClient()

Cluster
  • Workers: 1
  • Cores: 1
  • Memory: 4.29 GB

Running the function

Since we’re using a Dask function, the values aren’t computed automatically: calling edof() only builds a task graph. We need to call .compute() on the result to get the value.

edof(ds.t2m)
<ipython-input-2-d6dd81447a9b>:5: RuntimeWarning: invalid value encountered in true_divide
  c = c[N-1:]/(N-1-np.arange(0, N, 1))
        Array       Chunk
Bytes   8 B         8 B
Shape   ()          ()
Count   278 Tasks   1 Chunks
Type    float64     numpy.ndarray

%%time

edof(ds.t2m).compute()
<ipython-input-2-d6dd81447a9b>:5: RuntimeWarning: invalid value encountered in true_divide
  c = c[N-1:]/(N-1-np.arange(0, N, 1))
CPU times: user 3.15 s, sys: 360 ms, total: 3.51 s
Wall time: 3min 59s
27.634806731994875
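
The difference between the lazy result and the computed value can be reproduced on a toy array. This is a minimal sketch with a hypothetical array of ones, unrelated to the ERA5 data.

```python
import dask.array

# Hypothetical small array, split into 16 chunks
x = dask.array.ones((1000, 1000), chunks=(250, 250))

lazy_total = x.sum()           # builds a task graph, no work is done yet
total = lazy_total.compute()   # runs the graph and returns the value

print(type(lazy_total).__name__, total)
```

Until .compute() is called, `lazy_total` is still a dask array describing the work to do, just as edof(ds.t2m) is above.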