Xarray Groupby to Extract Single Cell Timeseries
A common task in microscopy data analysis is to track how single-cell properties change over time.
After collecting data, you segment and track your cells, so you know which pixels of each image belong to which cell. But how do you efficiently convert that mask and the fluorescence channels into a measurement over time for every cell?
For example, how would we make this plot:

Enter Groupby
Xarray has a solution for this exact problem: groupby. When you think of groupby, you likely think of grouping by a one-dimensional variable, for example grouping by time point.
But Xarray goes beyond this because it supports multidimensional coordinates. This abstraction means that we can use our segmentation mask as a multidimensional coordinate and group by it.
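Here is a minimal sketch of the idea on a tiny hand-made array. The 4x4 intensities and the `label` coordinate are invented for illustration. The 2-D `label` coordinate plays the role of a segmentation mask, and grouping by it collapses the spatial dimensions into one value per label.

```python
import numpy as np
import xarray as xr

# A toy 4x4 "image" and a 2-D label array of the same shape (0 = background)
intensity = np.arange(16, dtype=float).reshape(4, 4)
labels = np.array([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [2, 2, 0, 0],
    [2, 2, 0, 0],
])

# Attach the labels as a multidimensional (Y, X) coordinate
da = xr.DataArray(intensity, dims=("Y", "X"), coords={"label": (("Y", "X"), labels)})

# Group by the 2-D coordinate and average the pixels belonging to each label
per_label = da.groupby("label").mean()
print(per_label)
```

The result has a single `label` dimension, with means 7.5, 4.5 and 10.5 for labels 0, 1 and 2.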
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
Simulate data
To keep this example simple, this tutorial uses simulated data. Our simulation has spherical cells with two fluorescent channels. Each channel's intensity oscillates as a sine wave with a randomly drawn amplitude, period and offset for each cell. Here is what that looks like. The next cell is collapsed but contains all the code to replicate the data in this example.
ds = make_groupby_data(seed=1024)
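You may not have `make_groupby_data` at hand. In that case, the hypothetical `make_toy_data` below is a much smaller stand-in. Its name, parameters and defaults are assumptions and it is not the tutorial's simulator. Disk-shaped cells take a random walk, and each cell's GFP and RFP intensities follow a sine wave with a per-cell random amplitude, period and phase. It returns a Dataset with the same layout as `ds` here: an `images` variable over `(T, C, Y, X)` and a `cell_id` mask coordinate over `(T, Y, X)`.

```python
import numpy as np
import xarray as xr


def make_toy_data(n_t=8, size=64, n_cells=3, radius=6, seed=0):
    """Hypothetical, simplified stand-in for make_groupby_data."""
    rng = np.random.default_rng(seed)
    yy, xx = np.mgrid[0:size, 0:size]

    # Sparse, non-zero cell labels (0 is reserved for background)
    labels = rng.choice(np.arange(1, 1000), size=n_cells, replace=False).astype(np.uint16)
    centers = rng.uniform(radius, size - radius, size=(n_cells, 2))

    # Per-cell, per-fluorescent-channel (GFP, RFP) sine-wave parameters
    amp = rng.uniform(50, 200, size=(n_cells, 2))
    period = rng.uniform(4, 12, size=(n_cells, 2))
    phase = rng.uniform(0, 2 * np.pi, size=(n_cells, 2))

    masks = np.zeros((n_t, size, size), dtype=np.uint16)
    images = np.zeros((n_t, 3, size, size))
    for t in range(n_t):
        # Random-walk step, clipped so cells stay inside the field of view
        centers = np.clip(centers + rng.normal(0, 1, centers.shape), radius, size - radius)
        for i, label in enumerate(labels):
            disk = (yy - centers[i, 0]) ** 2 + (xx - centers[i, 1]) ** 2 <= radius**2
            masks[t][disk] = label
            images[t, 0][disk] = 1.0  # crude "brightfield" placeholder
            signal = amp[i] * (1 + np.sin(2 * np.pi * t / period[i] + phase[i]))
            images[t, 1][disk] = signal[0]  # GFP
            images[t, 2][disk] = signal[1]  # RFP

    return xr.Dataset(
        {"images": (("T", "C", "Y", "X"), images)},
        coords={"C": ["BF", "GFP", "RFP"], "cell_id": (("T", "Y", "X"), masks)},
    )
```

For example, `ds = make_toy_data(seed=1024)` gives a small dataset the rest of the steps can run on. Overlapping cells simply overwrite each other here, which a real simulator may handle differently.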
A quick look at the data
Here you can see the simulated data: spherical cells that move around according to a random walk. Each cell constitutively expresses both GFP and RFP, with intensity varying over time.
fig, axs = plt.subplots(3, 10, sharex=True, sharey=True, figsize=(10, 4))
cmaps = {0: 'gray', 1: 'Greens', 2: 'Reds'}
# one row per channel, one column per sampled time point (every 4th frame)
for c, C in enumerate(ds.C):
    for t, T in enumerate(np.arange(0, 40, 40 // 10)):
        axs[c, t].imshow(ds['images'].sel(T=T, C=C), cmap=cmaps[c])
        # hide the ticks instead of calling axis('off'), which would also hide the row labels
        axs[c, t].set_xticks([])
        axs[c, t].set_yticks([])
fig.supxlabel("Time")
axs[0, 0].set_ylabel("BF")
axs[1, 0].set_ylabel("GFP")
axs[2, 0].set_ylabel("RFP")
plt.tight_layout()

Looking at the Xarray repr, we can see that cell_id was included as a coordinate variable. Here, the knowledge of which pixel corresponds to which cell at any given time point is more metadata than data, so we have included it as a coordinate. You can also include it as a data variable if that makes more sense in your application. However, in order to use groupby you will need it to be a coordinate variable.
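Both layouts are easy to convert between. Below is a minimal sketch on random toy arrays; the shapes and values are made up. `Dataset.set_coords` promotes an existing data variable to a coordinate, and `assign_coords` attaches a mask directly as a coordinate.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
images = rng.random((2, 4, 4))                                 # toy (T, Y, X) intensities
masks = rng.integers(0, 3, size=(2, 4, 4)).astype(np.uint16)   # toy (T, Y, X) labels

# Option 1: the mask starts out as a data variable...
ds_var = xr.Dataset({
    "images": (("T", "Y", "X"), images),
    "cell_id": (("T", "Y", "X"), masks),
})
# ...and is promoted to a coordinate variable
ds_coord = ds_var.set_coords("cell_id")

# Option 2: attach the mask directly as a coordinate
ds_coord2 = xr.Dataset({"images": (("T", "Y", "X"), images)}).assign_coords(
    cell_id=(("T", "Y", "X"), masks)
)
```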
ds
<xarray.Dataset> Size: 273MB
Dimensions:  (C: 3, T: 40, Y: 512, X: 512)
Coordinates:
  * C        (C) <U3 36B 'BF' 'GFP' 'RFP'
    cell_id  (T, Y, X) uint16 21MB 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
Dimensions without coordinates: T, Y, X
Data variables:
    images   (T, C, Y, X) float64 252MB 2.493 115.6 0.0 0.0 ... 3.729 0.0 0.0
# taking a quick look at the masks
plt.figure()
plt.imshow(ds['cell_id'].sel(T=0), cmap='tab20')
plt.title("Cell masks")
plt.colorbar()

Now we are ready to use groupby to extract single-cell properties over time. Groupby works on a split-apply-combine principle, so the first step is to split our data up by individual cell. To do this, we simply tell Xarray which coordinate variable to make into groups.
In this case there are 12 unique cells in our field of view, plus one group for the background (cell_id=0), so our grouper will have 13 groups.
# construct the groupby
per_cell_groupby = ds.groupby(["cell_id"])
print(per_cell_groupby)
<DatasetGroupBy, grouped over 1 grouper(s), 13 groups in total:
'cell_id': UniqueGrouper('cell_id'), 13/13 groups with labels 0, 270, 409, ..., 2987, 3282, 3353>
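You can inspect the split step directly, because iterating over a groupby object yields `(label, group)` pairs. Here is a small sketch on a hand-made 2x3 array with invented values and labels. Each group contains only the pixels that carry its label.

```python
import numpy as np
import xarray as xr

labels = np.array([[0, 1, 1],
                   [0, 2, 2]])
da = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("Y", "X"),
                  coords={"cell_id": (("Y", "X"), labels)})

# Split: each iteration yields one label and the pixels assigned to it
for cell_id, group in da.groupby("cell_id"):
    print(cell_id, group.values)
```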
To turn this into a useful plot, we now need to apply an aggregation function. For tracking a cell-cycle marker this can be as simple as taking the mean, but you can write arbitrarily complex reductions and apply them.
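Here is a sketch of going beyond the mean, again on toy data; `max_minus_min` is a made-up statistic. Built-in reductions such as `count` give, for example, each cell's area in pixels. `map` applies an arbitrary function to every group and concatenates the results along `cell_id`.

```python
import numpy as np
import xarray as xr

labels = np.array([[0, 1, 1],
                   [0, 2, 2],
                   [0, 2, 2]])
da = xr.DataArray(np.arange(9.0).reshape(3, 3), dims=("Y", "X"),
                  coords={"cell_id": (("Y", "X"), labels)})


def max_minus_min(group):
    # Hypothetical per-cell statistic: spread of intensities inside the cell
    return group.max() - group.min()


gb = da.groupby("cell_id")
area = gb.count()                # number of pixels per cell
spread = gb.map(max_minus_min)   # custom function applied to each group
```

With these toy values, `area` is `[3, 2, 4]` and `spread` is `[6.0, 1.0, 4.0]` for cells 0, 1 and 2.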
# Aggregate over our spatial dimensions
spatial_averaged = per_cell_groupby.mean(["X", "Y"], fill_value=np.nan)
# drop the background
spatial_averaged = spatial_averaged.drop_sel({"cell_id": 0}).rename({"images":"cell_avg"})
spatial_averaged
<xarray.Dataset> Size: 12kB
Dimensions:   (T: 40, C: 3, cell_id: 12)
Coordinates:
  * C         (C) <U3 36B 'BF' 'GFP' 'RFP'
  * cell_id   (cell_id) uint16 24B 270 409 1134 1235 ... 1976 2987 3282 3353
Dimensions without coordinates: T
Data variables:
    cell_avg  (T, C, cell_id) float64 12kB nan 1.188e+03 ... 2.222e+04 9.988e+03
Notice that the cell_id coordinate is now one-dimensional, reflecting the fact that we have removed the spatial component through our grouping and aggregation.
This makes it easy to use Xarray’s plotting functionality to automatically plot single-cell traces of GFP and RFP as a function of time.
spatial_averaged.sel(C=["GFP", "RFP"]).cell_avg.plot(hue="cell_id", col="C")

While this is a simple example, it demonstrates a powerful paradigm: you can apply arbitrary aggregation functions in the groupby to easily compute any single-cell property as a function of time.
For a small example you might have been able to write this as a for loop, but as your datasets grow larger, Xarray’s groupby can yield significant performance improvements because of how it uses flox to speed up out-of-memory groupbys. You can read more about this here: https://xarray.dev/blog/flox
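For comparison, a hand-written version of the same per-cell, per-frame mean might look like the sketch below. It is plain NumPy, not how Xarray or flox implement groupby. `per_cell_mean_loop` is a hypothetical helper that loops over frames and uses `np.bincount` to sum intensities and pixel counts per label.

```python
import numpy as np


def per_cell_mean_loop(images, masks):
    """Hypothetical helper: per-frame mean intensity of every labelled cell.

    images, masks: arrays of shape (T, Y, X); masks hold integer labels, 0 = background.
    Returns (labels, means), with means of shape (T, n_labels) and NaN where a
    cell is absent from a frame.
    """
    labels = np.unique(masks)
    labels = labels[labels != 0]  # drop the background

    # Lookup table mapping each label value to a column index 0..n_labels-1
    lookup = np.zeros(int(masks.max()) + 1, dtype=np.intp)
    lookup[labels] = np.arange(labels.size)

    means = np.full((images.shape[0], labels.size), np.nan)
    for t in range(images.shape[0]):
        m = masks[t].ravel()
        keep = m != 0
        idx = lookup[m[keep]]
        # Sum of intensities and number of pixels per label in this frame
        sums = np.bincount(idx, weights=images[t].ravel()[keep], minlength=labels.size)
        counts = np.bincount(idx, minlength=labels.size)
        present = counts > 0
        means[t, present] = sums[present] / counts[present]
    return labels, means
```

Even this simple version has to manage the label-to-column mapping and missing cells (NaN) by hand. Groupby takes care of that bookkeeping for you.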