You can run this notebook in a live session on Binder or view it on GitHub.

Xarray and Dask

This notebook demonstrates one of xarray’s most powerful features: the ability to wrap dask arrays and allow users to seamlessly execute analysis code in parallel.

By the end of this notebook, you will:

  1. Learn that Xarray DataArrays and Datasets are “dask collections”, i.e. you can execute top-level dask functions such as dask.visualize(xarray_object)

  2. Learn that all xarray built-in operations can transparently use dask

  3. Learn that xarray provides tools to easily parallelize custom functions across blocks of dask-backed xarray objects.

Table of contents

  1. Reading data with Dask and Xarray

  2. Parallel/streaming/lazy computation using dask.array with Xarray

  3. Automatic parallelization with apply_ufunc and map_blocks

First, let's do the necessary imports, start a dask cluster, and test the dashboard.

[1]:
import expectexception
import numpy as np
import xarray as xr

Let's set up a LocalCluster using dask.distributed.

You can use any kind of dask cluster. This step is completely independent of xarray.

[2]:
from dask.distributed import Client

client = Client()
client
[2]:

Client

Cluster

  • Workers: 2
  • Cores: 2
  • Memory: 8.36 GB

👆

Click the Dashboard link above, or click the “Search” button in the dashboard.

Let’s test that the dashboard is working…

[3]:
import dask.array

dask.array.ones(
    (1000, 4), chunks=(2, 1)
).compute()  # should see activity in dashboard
[3]:
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       ...,
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])

Reading data with Dask and Xarray

The chunks argument to both open_dataset and open_mfdataset allows you to read datasets as dask arrays. See https://xarray.pydata.org/en/stable/dask.html#reading-and-writing-data for more details.

[4]:
ds = xr.tutorial.open_dataset(
    "air_temperature",
    chunks={
        "lat": 25,
        "lon": 25,
        "time": -1,
    },  # this tells xarray to open the dataset as a dask array
)
ds
[4]:
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 dask.array<chunksize=(2920, 25, 25), meta=np.ndarray>
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...

The repr for the air DataArray shows the dask repr.

[5]:
ds.air
[5]:
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<open_dataset-8f4d41e7ebef96085c123d71866d9e37air, shape=(2920, 25, 53), dtype=float32, chunksize=(2920, 25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]
[6]:
ds.air.chunks
[6]:
((2920,), (25,), (25, 25, 3))

Tip: Variables in a Dataset do not all need to have the same chunk size along common dimensions.
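As a quick illustration of this tip, here is a sketch with a tiny synthetic Dataset (not the tutorial data) whose two variables share a dimension but use different chunk sizes:

```python
import numpy as np
import xarray as xr

# Two variables share the "x" dimension but use different chunk sizes
ds_demo = xr.Dataset(
    {"a": ("x", np.arange(8)), "b": ("x", np.arange(8.0))}
)
ds_demo["a"] = ds_demo.a.chunk({"x": 4})
ds_demo["b"] = ds_demo.b.chunk({"x": 2})

print(ds_demo.a.chunks)  # ((4, 4),)
print(ds_demo.b.chunks)  # ((2, 2, 2, 2),)
```

Note that accessing the Dataset-level `.chunks` property would raise an error here, since the chunks are inconsistent across variables; the per-variable chunks are always accessible.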

[7]:
mean = ds.air.mean("time")  # no activity on dashboard
mean  # contains a dask array
[7]:
<xarray.DataArray 'air' (lat: 25, lon: 53)>
dask.array<mean_agg-aggregate, shape=(25, 53), dtype=float32, chunksize=(25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

This is true for all xarray operations, including slicing.

[8]:
ds.air.isel(lon=1, lat=20)
[8]:
<xarray.DataArray 'air' (time: 2920)>
dask.array<getitem, shape=(2920,), dtype=float32, chunksize=(2920,), chunktype=numpy.ndarray>
Coordinates:
    lat      float32 25.0
    lon      float32 202.5
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]

and more complicated operations…

Parallel/streaming/lazy computation using dask.array with Xarray

Xarray seamlessly wraps dask so all computation is deferred until explicitly requested

[10]:
timeseries = (
    ds.air.rolling(time=5).mean().isel(lon=1, lat=20)
)  # no activity on dashboard
timeseries  # contains dask array
[10]:
<xarray.DataArray (time: 2920)>
dask.array<getitem, shape=(2920,), dtype=float32, chunksize=(2918,), chunktype=numpy.ndarray>
Coordinates:
    lat      float32 25.0
    lon      float32 202.5
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
[11]:
timeseries = ds.air.rolling(time=5).mean()  # no activity on dashboard
timeseries  # contains dask array
[11]:
<xarray.DataArray (time: 2920, lat: 25, lon: 53)>
dask.array<where, shape=(2920, 25, 53), dtype=float32, chunksize=(2918, 25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

Getting concrete values from dask arrays

At some point, you will want to actually get concrete values from dask.

There are two ways to compute values on dask arrays. These concrete values are usually NumPy arrays, but could be, for example, a pydata/sparse array.

  1. .compute() returns an xarray object

  2. .load() replaces the dask array in the xarray object with a numpy array. This is equivalent to ds = ds.compute()

[12]:
computed = mean.compute()  # activity on dashboard
computed  # has real numpy values
[12]:
<xarray.DataArray 'air' (lat: 25, lon: 53)>
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
        253.43741],
       [262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
        254.35849],
       [264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
        257.71475],
       ...,
       [297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
        295.81622],
       [298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
        296.44348],
       [298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
        297.30502]], dtype=float32)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

Note that mean still contains a dask array

[13]:
mean
[13]:
<xarray.DataArray 'air' (lat: 25, lon: 53)>
dask.array<mean_agg-aggregate, shape=(25, 53), dtype=float32, chunksize=(25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

But if we call .load(), mean will now contain a numpy array

[14]:
mean.load()
[14]:
<xarray.DataArray 'air' (lat: 25, lon: 53)>
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
        253.43741],
       [262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
        254.35849],
       [264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
        257.71475],
       ...,
       [297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
        295.81622],
       [298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
        296.44348],
       [298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
        297.30502]], dtype=float32)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

Let’s check that again…

[15]:
mean
[15]:
<xarray.DataArray 'air' (lat: 25, lon: 53)>
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
        253.43741],
       [262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
        254.35849],
       [264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
        257.71475],
       ...,
       [297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
        295.81622],
       [298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
        296.44348],
       [298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
        297.30502]], dtype=float32)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

Tip: .persist() loads the values into distributed RAM. This is useful if you will be repeatedly using a dataset for computation but it is too large to load into local memory. You will see a persistent task on the dashboard.

See https://docs.dask.org/en/latest/api.html#dask.persist for more
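A minimal sketch of the mechanics, using a small synthetic array (so it fits comfortably in memory):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(1000.0), dims="x").chunk({"x": 100})

persisted = da.persist()  # computation happens now; chunks stay in (distributed) memory

# The result is still dask-backed, so downstream operations remain lazy...
print(persisted.chunks)

# ...but they now start from the in-memory chunks rather than recomputing.
print(float(persisted.mean()))  # 499.5
```

Without a distributed client, `.persist()` uses the default local scheduler; with the Client above, the chunks would live in the workers' memory.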

Extracting underlying data: .values vs .data

There are two ways to pull out the underlying data in an xarray object.

  1. .values will always return a NumPy array. For dask-backed xarray objects, this means that compute will always be called

  2. .data will return a Dask array

Exercise

Try extracting a dask array from ds.air

Now extract a NumPy array from ds.air. Do you see compute activity on your dashboard?
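For reference, the type difference is easy to see on a small synthetic array (the dashboard part of the exercise is left to you); this sketch is not part of the original notebook:

```python
import numpy as np
import xarray as xr
import dask.array

da = xr.DataArray(np.ones((4, 4)), dims=("x", "y")).chunk({"x": 2})

print(type(da.data))    # dask.array.Array -- lazy, no compute triggered
print(type(da.values))  # numpy.ndarray   -- accessing .values triggers compute
```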

Xarray data structures are first-class dask collections.

This means you can do things like dask.compute(xarray_object), dask.visualize(xarray_object), and dask.persist(xarray_object). This works for both DataArrays and Datasets.

Visualize the task graph for mean

Visualize the task graph for mean.data. Is that the same as the above graph?
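A minimal sketch of xarray objects behaving as dask collections, using a small synthetic array (dask.visualize works the same way but needs graphviz installed to render the graph):

```python
import dask
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0), dims="x").chunk({"x": 3})

# Top-level dask functions accept xarray objects directly
(result,) = dask.compute(da.mean())
print(result)  # a fully computed xarray.DataArray; the mean of 0..5 is 2.5
```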

Automatic parallelization with apply_ufunc and map_blocks

Almost all of xarray’s built-in operations work on Dask arrays.

Sometimes analysis calls for functions that aren’t in xarray’s API (e.g. scipy). There are three ways to apply these functions in parallel on each block of your xarray object:

  1. Extract Dask arrays from xarray objects (.data) and use Dask directly, e.g. apply_gufunc, map_blocks, map_overlap, or blockwise

  2. Use xarray.apply_ufunc() to apply functions that consume and return NumPy arrays.

  3. Use xarray.map_blocks(), Dataset.map_blocks() or DataArray.map_blocks() to apply functions that consume and return xarray objects.

Which method you use ultimately depends on the type of input objects expected by the function you’re wrapping, and the level of performance or convenience you desire.

map_blocks

map_blocks is inspired by the dask.array function of the same name and lets you map a function on blocks of the xarray object (including Datasets!).

At compute time, your function will receive an xarray object with concrete (computed) values along with appropriate metadata. This function should return an xarray object.

Here is an example

[16]:
def time_mean(obj):
    # use xarray's convenient API here
    # you could convert to a pandas dataframe and use pandas' extensive API
    # or use .plot() and plt.savefig to save visualizations to disk in parallel.
    return obj.mean("lat")


ds.map_blocks(time_mean)  # this is lazy!
[16]:
<xarray.Dataset>
Dimensions:  (lon: 53, time: 2920)
Coordinates:
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
  * lon      (lon) float64 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Data variables:
    air      (time, lon) float32 dask.array<chunksize=(2920, 25), meta=np.ndarray>
[17]:
# this will calculate values and will return True if the computation works as expected
ds.map_blocks(time_mean).identical(ds.mean("lat"))
[17]:
True

Exercise

Try applying the following function with map_blocks. Specify scale as an argument and offset as a kwarg.

The docstring should help: https://xarray.pydata.org/en/stable/generated/xarray.map_blocks.html

def time_mean_scaled(obj, scale, offset):
    return obj.mean("lat") * scale + offset
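One possible solution, sketched on a small synthetic stand-in for the tutorial dataset (per the map_blocks signature, positional arguments go in args and keyword arguments in kwargs):

```python
import numpy as np
import xarray as xr


def time_mean_scaled(obj, scale, offset):
    return obj.mean("lat") * scale + offset


# small synthetic stand-in for the tutorial dataset
ds_demo = xr.Dataset(
    {"air": (("time", "lat"), np.arange(12.0).reshape(4, 3))}
).chunk({"time": 2})

# scale passed positionally via args, offset via kwargs
result = ds_demo.map_blocks(time_mean_scaled, args=(2,), kwargs={"offset": 10})
print(result.compute().air.values)  # [12. 18. 24. 30.]
```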

More advanced functions

map_blocks needs to know what the returned object looks like exactly. It does so by passing a 0-shaped xarray object to the function and examining the result. This approach cannot work in all cases. For such advanced use cases, map_blocks allows a template kwarg. See https://xarray.pydata.org/en/latest/dask.html#map-blocks for more details.
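A minimal sketch of the template mechanism on synthetic data (here the output could actually be inferred automatically; the explicit template just shows the mechanics — a template is a lazy xarray object whose metadata and chunks match the expected result):

```python
import numpy as np
import xarray as xr


def double(obj):
    return obj * 2


ds_demo = xr.Dataset(
    {"air": (("time", "lat"), np.ones((4, 3)))}
).chunk({"time": 2})

# A lazy object describing the expected output: same dims, dtype, and chunks.
# Its values are never used, only its metadata.
template = ds_demo * 2

result = ds_demo.map_blocks(double, template=template)
print(result.compute().air.values)  # all 2.0
```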

apply_ufunc

apply_ufunc is a more advanced wrapper that is designed to apply functions that expect and return NumPy (or other arrays). For example, this would include all of SciPy’s API. Since apply_ufunc operates on lower-level NumPy or Dask objects, it skips the overhead of using Xarray objects, making it a good choice for performance-critical functions.

apply_ufunc can be a little tricky to get right since it operates at a lower level than map_blocks. On the other hand, Xarray uses apply_ufunc internally to implement much of its API, meaning that it is quite powerful!

A simple example

Simple functions that act independently on each value should work without any additional arguments. However, dask handling needs to be explicitly enabled.

[18]:
%%expect_exception

squared_error = lambda x, y: (x - y) ** 2

xr.apply_ufunc(squared_error, ds.air, 1)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-686bd005fe5a> in <module>
      1 squared_error = lambda x, y: (x - y) ** 2
      2
----> 3 xr.apply_ufunc(squared_error, ds.air, 1)

~/miniconda/envs/xarray/lib/python3.8/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, *args)
   1056         )
   1057     elif any(isinstance(a, DataArray) for a in args):
-> 1058         return apply_dataarray_vfunc(
   1059             variables_vfunc,
   1060             *args,

~/miniconda/envs/xarray/lib/python3.8/site-packages/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    231
    232     data_vars = [getattr(a, "variable", a) for a in args]
--> 233     result_var = func(*data_vars)
    234
    235     if signature.num_outputs > 1:

~/miniconda/envs/xarray/lib/python3.8/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, output_sizes, keep_attrs, meta, *args)
    572     if any(isinstance(array, dask_array_type) for array in input_data):
    573         if dask == "forbidden":
--> 574             raise ValueError(
    575                 "apply_ufunc encountered a dask array on an "
    576                 "argument, but handling for dask arrays has not "

ValueError: apply_ufunc encountered a dask array on an argument, but handling for dask arrays has not been enabled. Either set the ``dask`` argument or load your data into memory first with ``.load()`` or ``.compute()``

There are two options for the dask kwarg.

  1. dask="allowed" Dask arrays are passed to the user function. This is a good choice if your function can handle dask arrays and won’t call compute explicitly.

  2. dask="parallelized". This applies the user function over blocks of the dask array using dask.array.blockwise. This is useful when your function cannot handle dask arrays natively (e.g. scipy API).

Since squared_error can handle dask arrays without computing them, we specify dask="allowed".

[19]:
sqer = xr.apply_ufunc(squared_error, ds.air, 1, dask="allowed")
sqer  # dask-backed DataArray! with nice metadata!
[19]:
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<pow, shape=(2920, 25, 53), dtype=float32, chunksize=(2920, 25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

A more complicated example with a dask-aware function

For using more complex operations that consider some array values collectively, it’s important to understand the idea of core dimensions from NumPy’s generalized ufuncs. Core dimensions are defined as dimensions that should not be broadcast over. Usually, they correspond to the fundamental dimensions over which an operation is defined, e.g., the summed axis in np.sum. A good clue that core dimensions are needed is the presence of an axis argument on the corresponding NumPy function.

With apply_ufunc, core dimensions are recognized by name, and then moved to the last dimension of any input arguments before applying the given function. This means that for functions that accept an axis argument, you usually need to set axis=-1

Let’s use dask.array.mean as an example of a function that can handle dask arrays and uses an axis kwarg

[20]:
def time_mean(da):
    return xr.apply_ufunc(
        dask.array.mean,
        da,
        input_core_dims=[["time"]],
        dask="allowed",
        kwargs={"axis": -1},  # core dimensions are moved to the end
    )


time_mean(ds.air)
[20]:
<xarray.DataArray 'air' (lat: 25, lon: 53)>
dask.array<mean_agg-aggregate, shape=(25, 53), dtype=float32, chunksize=(25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
[21]:
ds.air.mean("time").identical(time_mean(ds.air))
[21]:
True

Automatically parallelizing dask-unaware functions

A very useful apply_ufunc feature is the ability to apply arbitrary functions in parallel to each block. This ability can be activated using dask="parallelized". Again, xarray needs a lot of extra metadata, so, depending on the function, extra arguments such as output_dtypes and output_sizes may be necessary.

We will use scipy.integrate.trapz as an example of a function that cannot handle dask arrays and requires a core dimension.

[22]:
import scipy as sp
import scipy.integrate

sp.integrate.trapz(ds.air.data)  # does NOT return a dask array
/home/travis/miniconda/envs/xarray/lib/python3.8/site-packages/dask/array/core.py:1352: FutureWarning: The `numpy.trapz` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
  warnings.warn(
[22]:
array([[12588.54  , 12582.26  , 12671.649 , ..., 15374.26  , 15430.039 ,
        15493.165 ],
       [12571.841 , 12567.279 , 12654.569 , ..., 15355.915 , 15413.14  ,
        15477.346 ],
       [12584.62  , 12537.54  , 12644.909 , ..., 15347.77  , 15399.9   ,
        15460.965 ],
       ...,
       [12709.4795, 12638.4795, 12810.2295, ..., 15416.831 , 15459.581 ,
        15510.4795],
       [12726.679 , 12634.4795, 12794.63  , ..., 15401.4795, 15454.13  ,
        15511.4795],
       [12767.33  , 12630.78  , 12754.531 , ..., 15446.33  , 15495.53  ,
        15538.18  ]], dtype=float32)

Exercise

Use apply_ufunc to apply sp.integrate.trapz along the time axis so that you get a dask array returned. You will need to specify dask="parallelized" and output_dtypes (a list of dtypes per returned variable).
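One possible solution, sketched on a small synthetic stand-in for ds.air (note: newer SciPy releases renamed trapz to trapezoid, so this sketch picks whichever is available):

```python
import numpy as np
import scipy.integrate
import xarray as xr

# SciPy renamed trapz -> trapezoid; use whichever this installation provides
trapz = getattr(scipy.integrate, "trapz", None) or scipy.integrate.trapezoid

# small synthetic stand-in for ds.air
da = xr.DataArray(
    np.ones((6, 2, 3)), dims=("time", "lat", "lon")
).chunk({"lat": 1})

integrated = xr.apply_ufunc(
    trapz,
    da,
    input_core_dims=[["time"]],  # the core dim is moved to the last axis
    kwargs={"axis": -1},
    dask="parallelized",
    output_dtypes=[da.dtype],    # one dtype per returned variable
)
print(integrated.data)  # still a dask array -- nothing computed yet
print(integrated.compute().values)  # trapezoidal integral of ones over 6 unit steps = 5.0
```

Note that with dask="parallelized" the core dimension ("time" here) must not be split across chunks, which is why the sketch chunks along "lat" only, matching the time=-1 chunking used when the tutorial dataset was opened.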