{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Converting a function from 1D to ND - apply_along_axis\n", "\n", "We have a function that works along the time axis, but only for 1D data. We'd like to convert it to work on a full dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "import xarray\n", "import numpy as np\n", "import scipy.integrate\n", "import dask.array" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1d function\n", "\n", "Our initial function, which works only on one dimension, is:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def edof_1d(A):\n", "    N = len(A)\n", "    x = A - A.mean()\n", "    c = np.correlate(x, x, 'full')\n", "    c = c[N-1:]/(N-1-np.arange(0, N, 1))\n", "    n = 0\n", "    while (c[n] > 0) and (n < N/2):\n", "        n = n+1\n", "    T = scipy.integrate.trapz(c[:n])/c[0]\n", "    edof = N/(2*T)\n", "    return edof" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Nd function\n", "\n", "In this case the function calls `numpy.correlate`, which needs the whole time axis. The best way to parallelise a function is to re-write it to compute the whole grid at once, stepping over time. See the example {doc}`oned_to_nd_rewrite`.\n", "\n", "A more general way to convert the 1D function to work on an ND dataset is to use [`dask.array.apply_along_axis`](https://docs.dask.org/en/latest/array-api.html#dask.array.apply_along_axis) (a dask-array version of [`numpy.apply_along_axis`](https://numpy.org/doc/stable/reference/generated/numpy.apply_along_axis.html)). This will process the chunks in parallel, but will be a bit slower than the re-write method." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def edof(A):\n", "    axis = 0  # Generally 'time' is axis zero, or use A.get_axis_num('time')\n", "    edof = dask.array.apply_along_axis(edof_1d, axis, A)\n", "    return np.nanmean(edof)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sample data\n", "\n", "To test performance, I'll use ERA5 data of approximately the target size. Note the horizontal chunks: we're working along the time axis, so the data should be chunked horizontally, with each chunk spanning the full time axis." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "path = \"/g/data/rt52/era5/single-levels/reanalysis/2t/2001/2t_era5_oper_sfc_200101*.nc\"\n", "ds = xarray.open_mfdataset(\n", "    path, combine=\"nested\", concat_dim=\"time\", chunks={'latitude': 91, 'longitude': 180}\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
```<xarray.DataArray 't2m' (time: 744, latitude: 721, longitude: 1440)>\n",
"dask.array<open_dataset-ec1e63fd41cfe92c2eb0ba0648acc61at2m, shape=(744, 721, 1440), dtype=float32, chunksize=(744, 91, 180), chunktype=numpy.ndarray>\n",
"Coordinates:\n",
"  * longitude  (longitude) float32 -180.0 -179.8 -179.5 ... 179.2 179.5 179.8\n",
"  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0\n",
"  * time       (time) datetime64[ns] 2001-01-01 ... 2001-01-31T23:00:00\n",
"Attributes:\n",
"    units:      K\n",
"    long_name:  2 metre temperature```
" ], "text/plain": [ "<xarray.DataArray 't2m' (time: 744, latitude: 721, longitude: 1440)>\n", "dask.array<open_dataset-ec1e63fd41cfe92c2eb0ba0648acc61at2m, shape=(744, 721, 1440), dtype=float32, chunksize=(744, 91, 180), chunktype=numpy.ndarray>\n", "Coordinates:\n", "  * longitude  (longitude) float32 -180.0 -179.8 -179.5 ... 179.2 179.5 179.8\n", "  * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0\n", "  * time       (time) datetime64[ns] 2001-01-01 ... 2001-01-31T23:00:00\n", "Attributes:\n", "    units:      K\n", "    long_name:  2 metre temperature" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.t2m" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Start a dask client\n", "\n", "I'm running on Gadi, where [`climtas.nci.GadiClient()`](https://climtas.readthedocs.io/en/latest/nci.html#climtas.nci.GadiClient) gets the available resources from PBS. Alternatively, you can start a client manually with [`dask.distributed.Client()`](https://distributed.dask.org/en/latest/quickstart.html#setup-dask-distributed-the-easy-way)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "
\n", "### Cluster\n", "\n", "• Workers: 1\n", "• Cores: 1\n", "• Memory: 4.29 GB\n", "
" ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import climtas.nci\n", "climtas.nci.GadiClient()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running the function\n", "\n", "Since we're using a Dask function, the values aren't computed immediately: calling `edof` only returns a lazy Dask array. To get the value we need to call `.compute()` on the result." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":5: RuntimeWarning: invalid value encountered in true_divide\n", " c = c[N-1:]/(N-1-np.arange(0, N, 1))\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "
\n", "Array: 8 B, shape (), 278 Tasks, float64\n", "Chunk: 8 B, shape (), 1 Chunks, numpy.ndarray\n", "
" ], "text/plain": [ "dask.array" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "edof(ds.t2m)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":5: RuntimeWarning: invalid value encountered in true_divide\n", " c = c[N-1:]/(N-1-np.arange(0, N, 1))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.15 s, sys: 360 ms, total: 3.51 s\n", "Wall time: 3min 59s\n" ] }, { "data": { "text/plain": [ "27.634806731994875" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "\n", "edof(ds.t2m).compute()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:analysis3]", "language": "python", "name": "conda-env-analysis3-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 4 }