import fsspec
import logging
import numcodecs
import numcodecs.abc
import numpy as np
import zarr
from fsspec.implementations.reference import LazyReferenceMapper
from kerchunk.utils import class_factory, dict_to_store, translate_refs_serializable
from kerchunk.codecs import AsciiTableCodec, VarArrCodec
try:
from astropy.wcs import WCS
from astropy.io import fits
except ModuleNotFoundError: # pragma: no cover
raise ImportError(
"astropy is required for kerchunking FIST files. Please install with "
"`pip/conda install astropy`."
)
logger = logging.getLogger("fits-to-zarr")
BITPIX2DTYPE = {
8: "uint8",
16: ">i2",
32: ">i4",
64: ">i8",
-32: ">f4",
-64: ">f8",
} # always bigendian
def process_file(
url,
storage_options=None,
extension=None,
inline_threshold=100,
primary_attr_to_group=False,
out=None,
):
"""
Create JSON references for a single FITS file as a zarr group
Parameters
----------
url: str
Where the file is
storage_options: dict
How to load that file (passed to fsspec)
    extension: int | str | list(int | str) | None
        Which extension(s) to include, given as ordinal integer(s) and/or extension
        name(s) (str). If None, the first extension containing data is used.
inline_threshold: int
(not yet implemented)
primary_attr_to_group: bool
Whether the output top-level group contains the attributes of the primary extension
(which often contains no data, just a general description)
out: dict-like or None
This allows you to supply an fsspec.implementations.reference.LazyReferenceMapper
to write out parquet as the references get filled, or some other dictionary-like class
to customise how references get stored
Returns
-------
dict of the references
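    Examples
    --------
    A minimal usage sketch; the URL and output file name are illustrative::

        import json
        refs = process_file("https://example.org/image.fits")
        with open("image.json", "w") as f:
            json.dump(refs, f)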
"""
from astropy.io import fits
storage_options = storage_options or {}
out = out or {}
store = dict_to_store(out)
g = zarr.open_group(store=store, zarr_format=2)
with fsspec.open(url, mode="rb", **storage_options) as f:
infile = fits.open(f, do_not_scale_image_data=True)
if extension is None:
found = False
for i, hdu in enumerate(infile):
if hdu.header.get("NAXIS", 0) > 0:
extension = [i]
found = True
break
if not found:
raise ValueError("No data extensions")
elif isinstance(extension, (int, str)):
extension = [extension]
for ext in extension:
hdu = infile[ext]
            hdu.header.__str__()  # rendering the header fixes invalid cards in place
attrs = dict(hdu.header)
kwargs = {"compressor": None}
if hdu.is_image:
# for images/cubes (i.e., ndarrays with simple type)
nax = hdu.header["NAXIS"]
shape = tuple(int(hdu.header[f"NAXIS{i}"]) for i in range(nax, 0, -1))
dtype = BITPIX2DTYPE[hdu.header["BITPIX"]]
length = np.dtype(dtype).itemsize
for s in shape:
length *= s
if "BSCALE" in hdu.header or "BZERO" in hdu.header:
kwargs["filters"] = [
numcodecs.FixedScaleOffset(
offset=float(hdu.header.get("BZERO", 0)),
scale=float(hdu.header.get("BSCALE", 1)),
astype=dtype,
dtype=float,
)
]
else:
kwargs["filters"] = None
elif isinstance(hdu, fits.hdu.table.TableHDU):
# ascii table
spans = hdu.columns._spans
outdtype = [
[name, hdu.columns[name].format.recformat]
for name in hdu.columns.names
]
indtypes = [
[name, f"S{i + 1}"] for name, i in zip(hdu.columns.names, spans)
]
nrows = int(hdu.header["NAXIS2"])
shape = (nrows,)
kwargs["filters"] = [AsciiTableCodec(indtypes, outdtype)]
dtype = [tuple(d) for d in outdtype]
length = (sum(spans) + len(spans)) * nrows
elif isinstance(hdu, fits.hdu.table.BinTableHDU):
# binary table
dtype = hdu.columns.dtype.newbyteorder(">") # always big endian
nrows = int(hdu.header["NAXIS2"])
shape = (nrows,)
# if hdu.fileinfo()["datSpan"] > length
if any(_.format.startswith(("P", "Q")) for _ in hdu.columns):
                    # contains variable-length array columns (P/Q descriptors)
length = hdu.fileinfo()["datSpan"]
dt2 = [
(
(name, "O")
if hdu.columns[name].format.startswith(("P", "Q"))
else (name, str(dtype[name].base))
+ ((dtype[name].shape,) if dtype[name].shape else ())
)
for name in dtype.names
]
types = {
name: hdu.columns[name].format[1]
for name in dtype.names
if hdu.columns[name].format.startswith(("P", "Q"))
}
kwargs["compressor"] = VarArrCodec(
str(dtype), str(dt2), nrows, types
)
kwargs["fill_value"] = None
dtype = dt2
else:
length = dtype.itemsize * nrows
kwargs["filters"] = None
else:
logger.warning(f"Skipping unknown extension type: {hdu}")
continue
            # one chunk for the whole array
            # TODO: we could sub-chunk along the largest dimension
name = hdu.name or str(ext)
arr = g.create_array(
name=name, dtype=dtype, shape=shape, chunks=shape, **kwargs
)
arr.attrs.update(
{
k: str(v) if not isinstance(v, (int, float, str)) else v
for k, v in attrs.items()
if k != "COMMENT"
}
)
arr.attrs["_ARRAY_DIMENSIONS"] = ["z", "y", "x"][-len(shape) :]
loc = hdu.fileinfo()["datLoc"]
parts = ".".join(["0"] * len(shape))
out[f"{name}/{parts}"] = [url, loc, length]
if primary_attr_to_group:
# copy attributes of primary extension to top-level group
hdu = infile[0]
            hdu.header.__str__()  # fix invalid cards, as above
g.attrs.update(
{
k: str(v) if not isinstance(v, (int, float, str)) else v
for k, v in dict(hdu.header).items()
if k != "COMMENT"
}
)
if isinstance(out, LazyReferenceMapper):
out.flush()
out = translate_refs_serializable(out)
return out
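
# class-based interface to process_file, built by kerchunk.utils.class_factory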
FitsToZarr = class_factory(process_file)
def add_wcs_coords(hdu, zarr_group=None, dataset=None, dtype="float32"):
"""Using FITS WCS, create materialised coordinate arrays
This may triple the data footprint of the data, as the coordinates can easily
be as big as the data itsel.
Must provide zarr_group or dataset
Parameters
----------
hdu: astropy.io.fits.HDU or dict
Input with WCS header information. If a dict, it is {key: attribute} of the data.
zarr_group: zarr.Group
To write the new arrays into
dataset: xr.Dataset
To create new coordinate arrays in; this is not necessarily written anywhere
dtype: str
Output numpy dtype
Returns
-------
If dataset is given, returns the modified dataset.
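    Examples
    --------
    A minimal sketch; the file and group names are illustrative::

        from astropy.io import fits
        import zarr
        with fits.open("image.fits") as hdul:
            grp = zarr.open_group("coords.zarr", mode="w")
            add_wcs_coords(hdul[0], zarr_group=grp)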
"""
if zarr_group is None and dataset is None:
raise ValueError("please provide a zarr group or xarray dataset")
if isinstance(hdu, dict):
# assume dict-like
head = fits.Header()
hdu2 = hdu.copy()
hdu2.pop("COMMENT", None) # comment fields can be non-standard
head.update(hdu2)
hdu = fits.PrimaryHDU(header=head)
elif not isinstance(hdu, fits.hdu.base._BaseHDU):
raise TypeError("`hdu` must be a FITS HDU or dict")
nax = hdu.header["NAXIS"]
shape = tuple(int(hdu.header[f"NAXIS{i}"]) for i in range(nax, 0, -1))
wcs = WCS(hdu)
coords = [
coo.ravel() for coo in np.meshgrid(*(np.arange(sh) for sh in shape))
] # ?[::-1]
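    # convert every pixel position to world coordinates via the WCS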
world_coords = wcs.pixel_to_world(*coords)
for i, (name, world_coord) in enumerate(zip(wcs.axis_type_names, world_coords)):
dims = ["z", "y", "x"][3 - len(shape) :]
attrs = {
"unit": world_coord.unit.name,
"type": hdu.header[f"CTYPE{i + 1}"],
"_ARRAY_DIMENSIONS": dims,
}
if zarr_group is not None:
arr = zarr_group.empty(
name, shape=shape, chunks=shape, dtype=dtype, exists_ok=True
)
arr.attrs.update(attrs)
arr[:] = world_coord.value.reshape(shape)
if dataset is not None:
import xarray as xr
            coo = xr.Variable(
                dims=dims, data=world_coord.value.reshape(shape), attrs=attrs
            )
            dataset = dataset.assign_coords({name: coo})
if dataset is not None:
return dataset