使用 Xarray 和 Dask 加快选择组合 netCDF 文件中的元素

Job*_*obS 6 dask python-xarray

我是 Xarray 和 Dask 的新手,并尝试访问以 3H 间隔存储全球洋流速度的多个 netCDF 文件。每个netCDF文件覆盖一个时间间隔的1/4度分辨率的网格数据:

NetCDF dimension information:
Name: time
    size: 1
    type: dtype('float32')
    _FillValue: 9.96921e+36
    units: 'days since 1950-01-01 00:00:00 UTC'
    calendar: 'julian'
    axis: 'T'
Name: lat
    size: 720
    type: dtype('float32')
    _FillValue: 9.96921e+36
    units: 'degrees_north'
    axis: 'Y'
Name: lon
    size: 1440
    type: dtype('float32')
    _FillValue: 9.96921e+36
    units: 'degrees_east'
    axis: 'X'
NetCDF variable information:
Name: eastward_eulerian_current_velocity, northward_eulerian_current_velocity
    dimensions: ('time', 'lat', 'lon')
    size: 1036800
    type: dtype('float32')
    _FillValue: 9.96921e+36
    coordinates: 'lon lat'
    horizontal_scale_range: 'greater than 100 km'
    temporal_scale_range: '10 days'
    units: 'm s-1'
Run Code Online (Sandbox Code Playgroud)

为了评估船舶航线规划器中航线的旅行时间,我试图找到针对特定经度/纬度组合和时间指数选择当前速度值(向北和向东)的最快方法。

目前的方法:

首先,我使用 打开多个数据集xarray.open_mfdataset(),并在预处理后将其存储xarray.Dataset.to_netcdf(),同时保持类型 float32。然后,我重新打开数据集并使用 Dask 自动分块对其进行分块。对于连接 50 天的数据 ( n_days = 50),这会产生以下块:

  • “时间”:200
  • “长”:288
  • “纬度”:240

200 * 288 * 240 * 4 字节 (float32) * 2 个变量 = 110.6 MB(正确吗?)

n_days当这些文件很大时,存储和打开这些文件的成本很高。对于n_days = 50,存储的 netCDF 文件为 3.08 GB。

NetCDF dimension information:
Name: time
    size: 1
    type: dtype('float32')
    _FillValue: 9.96921e+36
    units: 'days since 1950-01-01 00:00:00 UTC'
    calendar: 'julian'
    axis: 'T'
Name: lat
    size: 720
    type: dtype('float32')
    _FillValue: 9.96921e+36
    units: 'degrees_north'
    axis: 'Y'
Name: lon
    size: 1440
    type: dtype('float32')
    _FillValue: 9.96921e+36
    units: 'degrees_east'
    axis: 'X'
NetCDF variable information:
Name: eastward_eulerian_current_velocity, northward_eulerian_current_velocity
    dimensions: ('time', 'lat', 'lon')
    size: 1036800
    type: dtype('float32')
    _FillValue: 9.96921e+36
    coordinates: 'lon lat'
    horizontal_scale_range: 'greater than 100 km'
    temporal_scale_range: '10 days'
    units: 'm s-1'
Run Code Online (Sandbox Code Playgroud)

预处理数据的函数xarray.open_mfdataset()

def load_current_data(start_date, n_days):
    # start_date: datetime(2016, 1, 1)
    # n_days: 50 --> loading 50 * 8 = 400 netCDF files
    # ds_fp: dataset filepath
    # local_paths: netCDF file pathslist

    if os.path.exists(ds_fp):
        return xr.open_dataset(ds_fp, chunks={'time': 'auto', 'lon': 'auto', 'lat': 'auto'})
    
        # Try opening current data locally, otherwise download from FTP server
    try:
        ds = xr.open_mfdataset(local_paths,
                               parallel=True,
                               combine='by_coords',
                               preprocess=convert_to_knots)
    except FileNotFoundError:
        # Download files from FTP server and save to local_paths
        ds = xr.open_mfdataset(local_paths,
                               parallel=True,
                               combine='by_coords',
                               preprocess=convert_to_knots)
    ds.to_netcdf(ds_fp)
    return xr.open_dataset(ds_fp, chunks={'time': 'auto', 'lon': 'auto', 'lat': 'auto'})
Run Code Online (Sandbox Code Playgroud)

load_current_data()返回一个 chunked xarray.Dataset,见下文。

<xarray.Dataset>
Dimensions:  (lat: 720, lon: 1440, time: 400)
Coordinates:
  * lon      (lon) float32 -179.875 -179.625 -179.375 ... 179.625 179.875
  * lat      (lat) float32 -89.875 -89.625 -89.375 ... 89.375 89.625 89.875
  * time     (time) object 2016-01-01 00:00:00 ... 2016-02-19 21:00:00
Data variables:
    u_knot   (time, lat, lon) float32 dask.array<chunksize=(200, 240, 288), meta=np.ndarray>
    v_knot   (time, lat, lon) float32 dask.array<chunksize=(200, 240, 288), meta=np.ndarray>
Run Code Online (Sandbox Code Playgroud)

为了获得网格点的实际当前速度,我编写了下面的函数。它需要 lon/lat 索引和所需 netCDF 文件的日期时间(例如 2016-01-01 00:00:00)。该函数加载数据并将load_current_data()其保存到内存中。然而,这只有在适合内存的情况下才有可能。最后,使用 来选择与 lon/lat 索引和日期时间参数相对应的值xarray.Dataset.isel().load(),请参见下文。

我使用 dask.cache 来缓存计算。

def convert_to_knots(ds):
    ds.attrs = {}
    arr2d = np.float32(np.ones((720, 1440)) * 1.94384)
    ds['u_knot'] = arr2d * ds['eastward_eulerian_current_velocity']
    ds['v_knot'] = arr2d * ds['northward_eulerian_current_velocity']
    ds = ds.drop_vars(['eastward_eulerian_current_velocity',
                       'eastward_eulerian_current_velocity_error',
                       'northward_eulerian_current_velocity',
                       'northward_eulerian_current_velocity_error'])
    return ds
Run Code Online (Sandbox Code Playgroud)

是否可以提高这段代码的性能?并且,具体来说:

  1. 对于较大的数据集,我使用的方法persist()成本更高。是否有更好的方法可以persist()在不使用数据或以不同方式使用数据的情况下有效地访问数据?我也尝试过load()将数据加载到内存中。这要快得多,但是,它不适用于大于内存的数据集。
  2. 设置compute=False有助于xarray.Dataset.to_netcdf()提高性能吗?存储的 netCDF 文件要小得多,因此存储和打开速度更快。需要在稍后阶段计算值,但我没能做到这一点。(使用返回的dask.delayed对象?)