Source code for kerchunk.grib2

import base64
import copy
import io
import logging
from collections import defaultdict
from typing import Iterable, List, Dict, Set

import ujson

try:
    import cfgrib
except ModuleNotFoundError as err:  # pragma: no cover
    if err.name != "cfgrib":
        # an unrelated module failed to import; don't swallow that error
        raise
    raise ImportError(
        "cfgrib is needed to kerchunk GRIB2 files. Please install it with "
        "`conda install -c conda-forge cfgrib`. See https://github.com/ecmwf/cfgrib "
        "for more details."
    ) from err

import fsspec
import zarr
import xarray
import numpy as np

from kerchunk.utils import class_factory, _encode_for_JSON
from kerchunk.codecs import GRIBCodec
from kerchunk.combine import MultiZarrToZarr, drop


# cfgrib copies over certain GRIB attributes
# but renames them to CF-compliant values
ATTRS_TO_COPY_OVER = {
    "long_name": "GRIB_name",
    "units": "GRIB_units",
    "standard_name": "GRIB_cfName",
}
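# e.g. if a message decodes with "GRIB_name": "2 metre temperature" (illustrative
# value), the output variable also gets "long_name": "2 metre temperature".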

logger = logging.getLogger("grib2-to-zarr")


def _split_file(f: io.FileIO, skip=0):
    if hasattr(f, "size"):
        size = f.size
    else:
        size = f.seek(0, 2)
        f.seek(0)
    part = 0

    while f.tell() < size:
        logger.debug(f"extract part {part + 1}")
        head = f.read(1024)
        if b"GRIB" not in head:
            f.seek(-4, 1)
            continue
        ind = head.index(b"GRIB")
        start = f.tell() - len(head) + ind
        part_size = int.from_bytes(head[ind + 12 : ind + 16], "big")
        f.seek(start)
        yield start, part_size, f.read(part_size)
        part += 1
        if skip and part >= skip:
            break

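
# Illustrative sketch, not part of kerchunk: how ``_split_file`` can be used on its
# own to enumerate the byte ranges of the messages in a file. The URL is hypothetical.
def _example_split_file(url="file:///tmp/example.grib2"):  # pragma: no cover
    with fsspec.open(url, "rb") as f:
        for offset, size, raw in _split_file(f):
            # each yielded message starts with the b"GRIB" magic; its total length
            # was read from the low four bytes of the indicator section's length field
            assert raw[:4] == b"GRIB"
            print(offset, size)
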

def _store_array(store, z, data, var, inline_threshold, offset, size, attr):
    nbytes = data.dtype.itemsize
    for i in data.shape:
        nbytes *= i

    shape = tuple(data.shape or ())
    if nbytes < inline_threshold:
        logger.debug(f"Store {var} inline")
        d = z.create_dataset(
            name=var,
            shape=shape,
            chunks=shape,
            dtype=data.dtype,
            fill_value=attr.get("missingValue", None),
            compressor=False,
        )
        if hasattr(data, "tobytes"):
            b = data.tobytes()
        else:
            b = data.build_array().tobytes()
        try:
            # easiest way to test if data is ascii
            b.decode("ascii")
        except UnicodeDecodeError:
            b = b"base64:" + base64.b64encode(data)
        store[f"{var}/0"] = b.decode("ascii")
    else:
        logger.debug(f"Store {var} reference")
        d = z.create_dataset(
            name=var,
            shape=shape,
            chunks=shape,
            dtype=data.dtype,
            fill_value=attr.get("missingValue", None),
            filters=[GRIBCodec(var=var, dtype=str(data.dtype))],
            compressor=False,
            overwrite=True,
        )
        store[f"{var}/" + ".".join(["0"] * len(shape))] = ["{{u}}", offset, size]
    d.attrs.update(attr)

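
# Illustrative note, not part of kerchunk: ``_store_array`` produces one of two kinds
# of reference entries. Keys, offsets and lengths below are made-up examples.
#
#   store["latitude/0"] = "base64:AAAA..."          # small data, inlined in the refs
#   store["refc/0.0"] = ["{{u}}", 3211, 906080]     # large data: URL template "u",
#                                                   # byte offset and length in the file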

def scan_grib(
    url,
    common=None,
    storage_options=None,
    inline_threshold=100,
    skip=0,
    filter={},
):
    """
    Generate references for a GRIB2 file

    Parameters
    ----------
    url: str
        File location
    common: (deprecated, do not use)
    storage_options: dict
        For accessing the data, passed to filesystem
    inline_threshold: int
        If given, store array data smaller than this value directly in the output
    skip: int
        If non-zero, stop processing the file after this many messages
    filter: dict
        keyword filtering. For each key, only messages where the key exists and has
        the exact value or is in the given set, are processed.
        E.g., the cf-style filter ``{'typeOfLevel': 'heightAboveGround', 'level': 2}``
        only keeps messages where heightAboveGround==2.

    Returns
    -------
    list(dict): references dicts in Version 1 format, one per message in the file
    """
    import eccodes

    storage_options = storage_options or {}
    logger.debug(f"Open {url}")
    # This is hardcoded a lot in cfgrib!
    # valid_time is added if "time" and "step" are present in time_dims
    # These are present by default
    # TIME_DIMS = ["step", "time", "valid_time"]

    out = []
    with fsspec.open(url, "rb", **storage_options) as f:
        logger.debug(f"File {url}")
        for offset, size, data in _split_file(f, skip=skip):
            store = {}
            mid = eccodes.codes_new_from_message(data)
            m = cfgrib.cfmessage.CfMessage(mid)

            # It would be nice to just have a list of valid keys
            # There does not seem to be a nice API for this
            # 1. message_grib_keys returns keys coded in the message
            # 2. There exist "computed" keys, that are functions applied on the data
            # 3. There are also aliases!
            #    e.g. "number" is an alias of "perturbationNumber", and cfgrib uses this alias
            # So we stick to checking membership in 'm', which ends up doing
            # a lot of reads.
            message_keys = set(m.message_grib_keys())
            # The choices here copy cfgrib :(
            # message_keys.update(cfgrib.dataset.INDEX_KEYS)
            # message_keys.update(TIME_DIMS)
            # print("totalNumber" in cfgrib.dataset.INDEX_KEYS)
            # Adding computed keys adds a lot that isn't added by cfgrib
            # message_keys.extend(m.computed_keys)
            shape = (m["Ny"], m["Nx"])
            # thank you, gribscan
            native_type = eccodes.codes_get_native_type(m.codes_id, "values")
            data_size = eccodes.codes_get_size(m.codes_id, "values")
            coordinates = []

            good = True
            for k, v in (filter or {}).items():
                if k not in m:
                    good = False
                elif isinstance(v, (list, tuple, set)):
                    if m[k] not in v:
                        good = False
                elif m[k] != v:
                    good = False
            if good is False:
                continue

            z = zarr.open_group(store)
            global_attrs = {
                f"GRIB_{k}": m[k]
                for k in cfgrib.dataset.GLOBAL_ATTRIBUTES_KEYS
                if k in m
            }
            if "GRIB_centreDescription" in global_attrs:
                # follow CF compliant renaming from cfgrib
                global_attrs["institution"] = global_attrs["GRIB_centreDescription"]
            z.attrs.update(global_attrs)

            if data_size < inline_threshold:
                # read the data
                vals = m["values"].reshape(shape)
            else:
                # dummy array to match the required interface
                vals = np.empty(shape, dtype=native_type)
                assert vals.size == data_size

            attrs = {
                # Follow cfgrib convention and rename key
                f"GRIB_{k}": m[k]
                for k in cfgrib.dataset.DATA_ATTRIBUTES_KEYS
                + cfgrib.dataset.EXTRA_DATA_ATTRIBUTES_KEYS
                + cfgrib.dataset.GRID_TYPE_MAP.get(m["gridType"], [])
                if k in m
            }
            for k, v in ATTRS_TO_COPY_OVER.items():
                if v in attrs:
                    attrs[k] = attrs[v]

            # try to use cfVarName if available,
            # otherwise use the grib shortName
            varName = m["cfVarName"]
            if varName in ("undef", "unknown"):
                varName = m["shortName"]
            _store_array(store, z, vals, varName, inline_threshold, offset, size, attrs)

            if "typeOfLevel" in message_keys and "level" in message_keys:
                name = m["typeOfLevel"]
                coordinates.append(name)
                # convert to numpy scalar, so that .tobytes can be used for inlining
                # dtype=float is hardcoded in cfgrib
                data = np.array(m["level"], dtype=float)[()]
                try:
                    attrs = cfgrib.dataset.COORD_ATTRS[name]
                except KeyError:
                    logger.debug(f"Couldn't find coord {name} in dataset")
                    attrs = {}
                attrs["_ARRAY_DIMENSIONS"] = []
                _store_array(
                    store, z, data, name, inline_threshold, offset, size, attrs
                )

            dims = (
                ["y", "x"]
                if m["gridType"] in cfgrib.dataset.GRID_TYPES_2D_NON_DIMENSION_COORDS
                else ["latitude", "longitude"]
            )
            z[varName].attrs["_ARRAY_DIMENSIONS"] = dims

            for coord in cfgrib.dataset.COORD_ATTRS:
                coord2 = {
                    "latitude": "latitudes",
                    "longitude": "longitudes",
                    "step": "step:int",
                }.get(coord, coord)
                try:
                    x = m.get(coord2)
                except eccodes.WrongStepUnitError as e:
                    logger.warning(
                        "Ignoring coordinate '%s' for varname '%s', raises: eccodes.WrongStepUnitError(%s)",
                        coord2,
                        varName,
                        e,
                    )
                    continue
                if x is None:
                    continue
                coordinates.append(coord)
                inline_extra = 0
                if isinstance(x, np.ndarray) and x.size == data_size:
                    if (
                        m["gridType"]
                        in cfgrib.dataset.GRID_TYPES_2D_NON_DIMENSION_COORDS
                    ):
                        dims = ["y", "x"]
                        x = x.reshape(vals.shape)
                    else:
                        dims = [coord]
                        if coord == "latitude":
                            x = x.reshape(vals.shape)[:, 0].copy()
                        elif coord == "longitude":
                            x = x.reshape(vals.shape)[0].copy()
                    # force inlining of x/y/latitude/longitude coordinates,
                    # since these are derived from analytic formulae
                    # and are not stored in the message
                    inline_extra = x.nbytes + 1
                elif np.isscalar(x):
                    # convert python scalars to numpy scalar
                    # so that .tobytes can be used for inlining
                    x = np.array(x)[()]
                    dims = []
                else:
                    x = np.array([x])
                    dims = [coord]
                attrs = cfgrib.dataset.COORD_ATTRS[coord]
                _store_array(
                    store,
                    z,
                    x,
                    coord,
                    inline_threshold + inline_extra,
                    offset,
                    size,
                    attrs,
                )
                z[coord].attrs["_ARRAY_DIMENSIONS"] = dims

            if coordinates:
                z.attrs["coordinates"] = " ".join(coordinates)

            out.append(
                {
                    "version": 1,
                    "refs": _encode_for_JSON(store),
                    "templates": {"u": url},
                }
            )
    logger.debug("Done")
    return out
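

# Illustrative sketch, not part of kerchunk: open the first message scanned from one
# of the HRRR files used in ``example_combine`` below, following the xr.open_dataset
# pattern shown in that function's docstring.
def _example_open_first_message(
    url="s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t22z.wrfsfcf01.grib2",
):  # pragma: no cover
    refs = scan_grib(url, storage_options={"anon": True})[0]
    return xarray.open_dataset(
        "reference://",
        engine="zarr",
        backend_kwargs={
            "consolidated": False,
            "storage_options": {"fo": refs, "remote_options": {"anon": True}},
        },
    )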


GribToZarr = class_factory(scan_grib)


def example_combine(
    filter={"typeOfLevel": "heightAboveGround", "level": 2}
):  # pragma: no cover
    """Create combined dataset of weather measurements at 2m height

    Nine consecutive timepoints from nine ~120MB files on s3.
    Example usage:

    >>> tot = example_combine()
    >>> ds = xr.open_dataset("reference://", engine="zarr", backend_kwargs={
    ...        "consolidated": False,
    ...        "storage_options": {"fo": tot, "remote_options": {"anon": True}}})
    """
    files = [
        "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t22z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t23z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190102/conus/hrrr.t00z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190102/conus/hrrr.t01z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190102/conus/hrrr.t02z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190102/conus/hrrr.t03z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190102/conus/hrrr.t04z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190102/conus/hrrr.t05z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190102/conus/hrrr.t06z.wrfsfcf01.grib2",
    ]

    so = {"anon": True, "default_cache_type": "readahead"}
    out = [scan_grib(u, storage_options=so, filter=filter) for u in files]
    out = sum(out, [])

    mzz = MultiZarrToZarr(
        out,
        remote_protocol="s3",
        preprocess=drop(("valid_time", "step")),
        remote_options=so,
        concat_dims=["time", "var"],
        identical_dims=["heightAboveGround", "latitude", "longitude"],
    )
    return mzz.translate()


def grib_tree(
    message_groups: Iterable[Dict],
    remote_options=None,
) -> Dict:
    """
    Build a hierarchical data model from a set of scanned grib messages.

    The iterable input groups should be a collection of results from scan_grib.
    Multiple grib files can be processed together to produce an FMRC-like collection.
    The time (reference_time) and step coordinates will be used as concat_dims
    in the MultiZarrToZarr aggregation.

    Each variable name will become a group with nested subgroups representing the
    grib step type and grib level. The resulting hierarchy can be opened as a
    zarr_group or an xarray datatree.
    Grib message variable names that decode as "unknown" are dropped.
    Grib typeOfLevel attributes that decode as unknown are treated as a single group.
    Grib steps that are missing due to WrongStepUnitError are patched with NaT.
    The input message_groups should not be modified by this method.

    Parameters
    ----------
    message_groups: iterable[dict]
        a collection of zarr store like dictionaries as produced by scan_grib
    remote_options: dict
        remote options to pass to MultiZarrToZarr

    Returns
    -------
    dict: A new zarr store like dictionary for use as a reference filesystem mapper
        with zarr or xarray datatree
    """
    # Hard code the filters in the correct order for the group hierarchy
    filters = ["stepType", "typeOfLevel"]

    # TODO allow passing a LazyReferenceMapper as output?
    zarr_store = {}
    zroot = zarr.open_group(store=zarr_store)

    aggregations: Dict[str, List] = defaultdict(list)
    aggregation_dims: Dict[str, Set] = defaultdict(set)

    unknown_counter = 0
    for msg_ind, group in enumerate(message_groups):
        assert group["version"] == 1

        gattrs = ujson.loads(group["refs"][".zattrs"])
        coordinates = gattrs["coordinates"].split(" ")

        # Find the data variable
        vname = None
        for key, entry in group["refs"].items():
            name = key.split("/")[0]
            if name not in [".zattrs", ".zgroup"] and name not in coordinates:
                vname = name
                break

        if vname is None:
            raise RuntimeError(
                f"Can not find a data var for msg# {msg_ind} in {group['refs'].keys()}"
            )

        if vname == "unknown":
            # To resolve unknown variables add custom grib tables.
            # https://confluence.ecmwf.int/display/UDOC/Creating+your+own+local+definitions+-+ecCodes+GRIB+FAQ
            # If you process the groups from a single file in order, you can use the msg# to compare with the
            # IDX file. The idx files message index is 1 based where the grib_tree message count is zero based
            logger.warning(
                "Dropping unknown variable in msg# %d. Compare with the grib idx file to help identify it"
                " and build an ecCodes local grib definitions file to fix it.",
                msg_ind,
            )
            unknown_counter += 1
            continue

        logger.debug("Processing vname: %s", vname)
        dattrs = ujson.loads(group["refs"][f"{vname}/.zattrs"])
        # filter order matters - it determines the hierarchy
        gfilters = {}
        for key in filters:
            attr_val = dattrs.get(f"GRIB_{key}")
            if attr_val is None:
                continue
            if attr_val == "unknown":
                logger.warning(
                    "Found 'unknown' attribute value for key %s in var %s of msg# %s",
                    key,
                    vname,
                    msg_ind,
                )
                # Use unknown as a group or drop it?
            gfilters[key] = attr_val

        zgroup = zroot.require_group(vname)
        if "name" not in zgroup.attrs:
            zgroup.attrs["name"] = dattrs.get("GRIB_name")

        for key, value in gfilters.items():
            if value:  # Ignore empty string and None
                # name the group after the attribute values: surface, instant, etc
                zgroup = zgroup.require_group(value)
                # Add an attribute to give context
                zgroup.attrs[key] = value

        # Set the coordinates attribute for the group
        zgroup.attrs["coordinates"] = " ".join(coordinates)
        # add to the list of groups to multi-zarr
        aggregations[zgroup.path].append(group)

        # keep track of the level coordinate variables and their values
        for key, entry in group["refs"].items():
            name = key.split("/")[0]
            if name == gfilters.get("typeOfLevel") and key.endswith("0"):
                if isinstance(entry, list):
                    entry = tuple(entry)
                aggregation_dims[zgroup.path].add(entry)

    concat_dims = ["time", "step"]
    identical_dims = ["longitude", "latitude"]
    for path in aggregations.keys():
        # Parallelize this step!
        catdims = concat_dims.copy()
        idims = identical_dims.copy()

        level_dimension_value_count = len(aggregation_dims.get(path, ()))
        level_group_name = path.split("/")[-1]
        if level_dimension_value_count == 0:
            logger.debug(
                "Path %s has no coordinate value associated with the level name %s",
                path,
                level_group_name,
            )
        elif level_dimension_value_count == 1:
            idims.append(level_group_name)
        elif level_dimension_value_count > 1:
            # The level name should be the last element in the path
            catdims.insert(3, level_group_name)

        logger.info(
            "%s calling MultiZarrToZarr with idims %s and catdims %s",
            path,
            idims,
            catdims,
        )

        mzz = MultiZarrToZarr(
            aggregations[path],
            remote_options=remote_options,
            concat_dims=catdims,
            identical_dims=idims,
        )
        group = mzz.translate()

        for key, value in group["refs"].items():
            if key not in [".zattrs", ".zgroup"]:
                zarr_store[f"{path}/{key}"] = value

    # Force all stored values to decode as string, not bytes. String should be correct.
    # ujson will reject bytes values by default.
    # Using 'reject_bytes=False', one write would fail an equality check on read.
    zarr_store = {
        key: (val.decode() if isinstance(val, bytes) else val)
        for key, val in zarr_store.items()
    }

    # TODO handle other kerchunk reference spec versions?
    result = dict(refs=zarr_store, version=1)

    return result


def correct_hrrr_subhf_step(group: Dict) -> Dict:
    """
    Overrides the definition of the "step" variable.

    Sets the value equal to the `valid_time - time` in hours as a floating point value.
    This fixes issues with the HRRR SubHF grib2 step as read by cfgrib via scan_grib.

    The result is a deep copy; the original data is unmodified.

    Parameters
    ----------
    group: dict
        the zarr group store for a single grib message

    Returns
    -------
    dict: A new zarr store like dictionary for use as a reference filesystem mapper
        with zarr or xarray datatree
    """
    group = copy.deepcopy(group)
    group["refs"]["step/.zarray"] = (
        '{"chunks":[],"compressor":null,"dtype":"<f8","fill_value":"NaN","filters":null,"order":"C",'
        '"shape":[],"zarr_format":2}'
    )
    group["refs"]["step/.zattrs"] = (
        '{"_ARRAY_DIMENSIONS":[],"long_name":"time since forecast_reference_time",'
        '"standard_name":"forecast_period","units":"hours"}'
    )

    # add step to coords
    attrs = ujson.loads(group["refs"][".zattrs"])
    if "step" not in attrs["coordinates"]:
        attrs["coordinates"] += " step"
    group["refs"][".zattrs"] = ujson.dumps(attrs)

    fo = fsspec.filesystem("reference", fo=group, mode="r")
    xd = xarray.open_dataset(fo.get_mapper(), engine="zarr", consolidated=False)

    correct_step = xd.valid_time.values - xd.time.values

    assert correct_step.shape == ()
    step_float = correct_step.astype("timedelta64[s]").astype("float") / 3600.0
    step_bytes = step_float.tobytes()
    try:
        encoded_val = step_bytes.decode("ascii")
    except UnicodeDecodeError:
        encoded_val = (b"base64:" + base64.b64encode(step_bytes)).decode("ascii")

    group["refs"]["step/0"] = encoded_val

    return group
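

# Illustrative sketch, not part of kerchunk: combine several scans into a hierarchical
# reference set with ``grib_tree`` and open it as a datatree. The file list reuses two
# of the HRRR URLs from ``example_combine``; the ``datatree`` package is assumed to be
# installed. For HRRR sub-hourly output, each scanned group could first be passed
# through ``correct_hrrr_subhf_step``.
def _example_grib_tree():  # pragma: no cover
    import datatree

    so = {"anon": True}
    urls = [
        "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t22z.wrfsfcf01.grib2",
        "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t23z.wrfsfcf01.grib2",
    ]
    scans = []
    for u in urls:
        scans.extend(scan_grib(u, storage_options=so))
    tree = grib_tree(scans, remote_options=so)
    fs = fsspec.filesystem("reference", fo=tree, remote_options=so)
    return datatree.open_datatree(fs.get_mapper(""), engine="zarr", consolidated=False)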