import os
import copy
import logging
import numpy as np
import xarray as xr
from pathlib import Path
import astropy.units as u
from ForMoSA.core.errors import ForMoSAError
from ForMoSA.core.enums import WavelengthUnit
from ForMoSA.core.loggings import setup_logging
from ForMoSA.grid.grid_loader import GridLoader
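# ("[docs]" link marker left over from the rendered HTML source listing - presumably Sphinx viewcode; not module code.)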
class ModelGrid:
    '''
    Representation of a model grid with metadata.

    Wraps an ``xr.Dataset`` holding the grid (data variable ``"grid"``),
    handles loading from file and exposes basic properties.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset containing the model grid. Its ``attrs`` are updated in place
        (``grid_type`` is always set, ``wave_unit`` is set when missing).
    model_path : str | os.PathLike | None
        Path to the model grid file (``None`` for an in-memory grid)
    logger : logging.Logger
        Logger instance for logging. When ``None`` a new one is created
        with ``setup_logging``.
    log_level : str
        Level of the logger (only used when ``logger`` is ``None``)
    display_unit : WavelengthUnit
        Unit of the wavelength to display

    Raises
    ------
    ForMoSAError
        If the dataset does not pass ``GridLoader._validate_model_grid_dataset``.

    Notes
    -----
    Authors: Allan Denis
    '''

    def __init__(
        self,
        dataset: xr.Dataset,
        model_path: str | os.PathLike | None = None,
        logger: logging.Logger | None = None,
        log_level: str = "INFO",
        display_unit: WavelengthUnit = WavelengthUnit.MICROMETER,
    ) -> None:
        self._logger = logger if logger is not None else setup_logging(log_level, name='ModelGrid')

        # Validation
        try:
            GridLoader._validate_model_grid_dataset(dataset)
        except ForMoSAError as e:
            # Re-raise with this instance's logger, keeping the original cause.
            raise ForMoSAError(e, self.logger) from e

        self._grid = dataset
        self._model_path = model_path
        self._display_unit = display_unit

        # NOTE: this writes into the attrs of the dataset passed by the caller.
        self._grid.attrs['grid_type'] = self.GridType
        if 'wave_unit' not in self._grid.attrs:
            self._grid.attrs['wave_unit'] = str(WavelengthUnit.MICROMETER.unit)
            self._logger.warning(f"Wavelength unit not found in grid attributes. Setting to default: {WavelengthUnit.MICROMETER.unit}")

        self.logger.info(f' Grid dimensions: {tuple(self.dimensions)} {tuple(self.dims)}')

    # ================================================
    # Representation
    # ================================================
    def __repr__(self) -> str:
        return f"<ModelGrid name={self.grid_name} path={self.model_path} shape={self.grid['grid'].shape}>"

    def __getstate__(self) -> dict:
        """Return a picklable state: the logger is dropped and only its level/name are kept."""
        state = self.__dict__.copy()
        level_int = self._logger.level
        # Levels >= 100 are stored as 'OFF' (presumably the value setup_logging
        # uses for a disabled logger - TODO confirm); others by their level name.
        state['__pickle_log_level'] = 'OFF' if level_int >= 100 else logging.getLevelName(level_int)
        state['__pickle_log_name'] = self._logger.name.removeprefix('ForMoSA.')
        state['_logger'] = None
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore the state and rebuild the logger from the stored level/name."""
        log_level = state.pop('__pickle_log_level', 'INFO')
        log_name = state.pop('__pickle_log_name', 'ModelGrid')
        self.__dict__.update(state)
        self._logger = setup_logging(level=log_level, name=log_name)

    # ================================================
    # Properties
    # ================================================
    @property
    def suffix(self) -> str:
        """Suffix used for saving (overridden in subgrid)."""
        return 'native'

    @property
    def GridType(self) -> str:
        """Grid type (overridden in subgrid)."""
        return 'native'

    @property
    def grid(self) -> xr.Dataset:
        """Grid as xr.Dataset."""
        return self._grid

    @grid.setter
    def grid(self, grid_array) -> None:
        """Grid setter (no validation is performed on the new value)."""
        self._grid = grid_array

    @property
    def grid_as_dataarray(self) -> xr.DataArray:
        """Grid as xr.DataArray (more tailored to manipulations)."""
        da = self._grid["grid"].copy()
        # Propagate attributes of the Dataset
        da.attrs.update(self.attrs)
        return da

    @property
    def model_path(self) -> Path | str:
        """Path to the model ('in-memory-grid' when no path was given)."""
        if self._model_path is None:
            return 'in-memory-grid'
        return Path(self._model_path).expanduser()

    @property
    def grid_name(self) -> str:
        """Name of the grid (overridden in subgrid)."""
        # basename handles the platform's path separators; everything from the
        # first '.nc' onwards is then dropped.
        return os.path.basename(str(self.model_path)).split('.nc')[0]

    @property
    def logger(self) -> logging.Logger:
        """Logger."""
        return self._logger

    @property
    def native_unit(self) -> u.core.PrefixUnit:
        """Native unit of the wavelength."""
        # NOTE(review): this is a lookup by enum member *name*, while __init__
        # stores str(WavelengthUnit.MICROMETER.unit) as default 'wave_unit'.
        # Confirm that WavelengthUnit defines matching names/aliases.
        return WavelengthUnit[self.attrs['wave_unit']].unit

    @property
    def unit(self) -> u.core.PrefixUnit:
        """Unit of the wavelength to display."""
        return self._display_unit.unit

    @property
    def wave(self) -> np.ndarray:
        """Wavelength of the grid, converted from the native unit to the display unit."""
        native_wave = self.grid.coords['wavelength'].values * self.native_unit
        return native_wave.to(self.unit).value

    @property
    def res(self) -> np.ndarray:
        """Resolution of the grid (``res`` attribute of the dataset)."""
        return self.attrs['res']

    @property
    def attrs(self) -> dict:
        """Dictionary of attributes of the grid (a copy: edits do not reach the dataset)."""
        return dict(self.grid.attrs)

    @property
    def keys(self) -> list:
        """Keys of the grid parameters."""
        return self.attrs['key']

    @property
    def titles(self) -> list:
        """Names of the grid parameters."""
        return self.attrs['title']

    @property
    def key_values(self) -> dict:
        """Values taken by the grid parameters."""
        return {key: np.atleast_1d(self.grid[key].values) for key in self.keys}

    @property
    def lims_params_grid(self) -> dict:
        """Limits ``[min, max]`` of each grid parameter."""
        # Evaluate key_values once instead of twice per parameter.
        values = self.key_values
        return {par: [min(values[par]), max(values[par])] for par in self.keys}

    @property
    def nyquist(self) -> np.ndarray:
        """Nyquist sampling."""
        # self.wave performs a unit conversion on each access: evaluate it once.
        wave = self.wave
        if wave is None or len(wave) < 2:
            return wave
        # The last step is obtained by linear extrapolation so that the result
        # has the same length as the wavelength array.
        diff = np.diff(wave, append=(2 * wave[-1] - wave[-2]))
        return wave / (2 * diff)

    @property
    def effective_resolution(self) -> np.ndarray:
        """Effective resolution being the minimum between Nyquist sampling and resolution."""
        return np.minimum(self.res, self.nyquist)

    @property
    def n_grids(self) -> int:
        """Number of grids (size of the ``grid`` variable at index 0 of its first dimension)."""
        return self.grid['grid'][0].size

    @property
    def size(self) -> int:
        """Size of the grid."""
        return self.grid.data_vars['grid'].size

    @property
    def dims(self) -> list[str]:
        """List of names for each dimension of the grid."""
        # Dataset.sizes is the name -> length mapping; using Dataset.dims as a
        # mapping is deprecated in xarray.
        return list(self.grid.sizes.keys())

    @property
    def dimensions(self) -> list[int]:
        """List of number of points for each dimension."""
        return list(self.grid.sizes.values())

    # ================================================
    # Class methods
    # ================================================
    @classmethod
    def from_file(
        cls,
        path: str | os.PathLike,
        logger: logging.Logger | None = None,
        log_level: str = 'INFO',
        display_unit: WavelengthUnit = WavelengthUnit.MICROMETER,
    ):
        '''
        Generate grid from file.

        Parameters
        ----------
        path : str | os.PathLike
            Path to the grid
        logger : logging.Logger
            Logger instance for logging
        log_level : str
            Level of the logger (only used when ``logger`` is not given)
        display_unit : WavelengthUnit
            Unit of the wavelength to display

        Returns
        -------
        ModelGrid
            Instance of :class:`~ModelGrid`

        Examples
        --------
        >>> grid = ModelGrid.from_file(path)

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(level=log_level, name='ModelGrid')
        logger.debug(f'Loading ModelGrid from file {path}')
        ds = GridLoader._from_file(path)
        logger.info(f' ModelGrid generated from {path}')
        return cls(ds, model_path=path, logger=logger, display_unit=display_unit)

    @classmethod
    def _from_attributes(
        cls,
        data: np.ndarray,
        coords: dict,
        attrs: dict,
        logger: logging.Logger | None = None,
        log_level: str = 'INFO',
        display_unit: WavelengthUnit = WavelengthUnit.MICROMETER,
    ):
        '''
        Generate grid from attributes.

        Parameters
        ----------
        data : np.ndarray
            Array of the grid
        coords : dict
            Dictionary of coordinates
        attrs : dict
            Dictionary of attributes
        logger : logging.Logger
            Logger instance for logging
        log_level : str
            Level of the logger (only used when ``logger`` is not given)
        display_unit : WavelengthUnit
            Unit of the wavelength to display

        Returns
        -------
        ModelGrid
            Instance of :class:`~ModelGrid`

        Examples
        --------
        >>> grid = ModelGrid._from_attributes(data, coords, attrs)

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(level=log_level, name='ModelGrid')
        logger.debug('<Generate ModelGrid from attributes>')
        ds = GridLoader._from_data(data, coords, attrs)
        logger.info('ModelGrid generated')
        return cls(ds, logger=logger, display_unit=display_unit)

    # ================================================
    # Methods
    # ================================================
    def _set_unit(self, unit: WavelengthUnit) -> None:
        '''
        Method to set the unit used to display the wavelength.

        Parameters
        ----------
        unit : WavelengthUnit
            Unit used (a member of the ``WavelengthUnit`` enum)

        Raises
        ------
        ForMoSAError
            If ``unit`` is not a ``WavelengthUnit``.

        Notes
        -----
        Authors: Allan Denis
        '''
        if not isinstance(unit, WavelengthUnit):
            raise ForMoSAError(f"<Unit must be a WavelengthUnit Enum, got {type(unit)}>", self.logger)
        # Only the display unit changes; the log reports the grid's native unit
        # as the starting point, not the previous display unit.
        self._logger.info(f"<Convert the unit used for grid {self.grid_name} from {self.native_unit} to {unit.unit}>")
        self._display_unit = unit

    def _load_model_at_specific_index(self, idx: tuple) -> xr.DataArray:
        '''
        Load a model at a specific index.

        Parameters
        ----------
        idx : tuple
            Index of model to be loaded (e.g. (5, 0, 6, 3) for a 4-parameter grid)

        Returns
        -------
        model_to_return : xr.DataArray
            Model at the specific index. It is returned even when it contains
            NaNs (a warning is logged in that case).

        Raises
        ------
        ForMoSAError
            If ``idx`` is not a tuple.

        Notes
        -----
        Authors: Allan Denis
        '''
        if not isinstance(idx, tuple):
            raise ForMoSAError(f'<Index is type{type(idx)}. It should be a tuple>', self.logger)

        # idx addresses the trailing dimensions; the leading ones are kept whole.
        model_to_return = self.grid_as_dataarray[(..., ) + idx]

        if np.any(np.isnan(model_to_return)):
            # Evaluate key_values once rather than once per parameter.
            values = self.key_values
            msg = 'Extraction of model failed : '
            for i, (key, title) in enumerate(zip(self.keys, self.titles)):
                msg += f'{title}={values[key][idx[i]]}, '
            self._logger.warning(f' {msg}')

        return model_to_return

    def _interpolate_missing_values(
        self,
        method: str = "linear",
        limit: int | None = None,
        fill_value: str = 'extrapolate',
        max_gap: int | None = None,
    ) -> None:
        '''
        Interpolate missing (NaN) values in the grid, in place.

        The grid parameters are processed one after the other, in the order
        of ``keys``; a parameter is skipped when the grid holds no NaN by
        the time it is reached.

        Parameters
        ----------
        method : str
            Interpolation method to use.
        limit : int | None
            Maximum number of consecutive NaNs to fill.
        fill_value : str
            Method to fill in points outside of data range
        max_gap : int | None
            Maximum size of gap, a continuous sequence of NaNs, that will be filled

        Notes
        -----
        Authors: Simon Petrus, Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self.logger.info(f' Interpolating between holes of the grid {self.grid_name}')

        interp_kwargs = {"method": method, "fill_value": fill_value}
        # limit / max_gap are only forwarded when explicitly provided.
        if limit is not None:
            interp_kwargs["limit"] = limit
        if max_gap is not None:
            interp_kwargs["max_gap"] = max_gap

        self._logger.info(f" {self.grid_name}")
        n_keys = len(self.keys)
        for idx, (key, title) in enumerate(zip(self.keys, self.titles)):
            self.logger.info(f' {idx + 1}/{n_keys} - {title}')
            if bool(self.grid["grid"].isnull().any(dim=key).any().item()):
                self._grid = self.grid.interpolate_na(dim=key, **interp_kwargs)

    def _nan_interpolated_grid(self):
        '''
        Return a 1D interpolated grid filled with NaNs,
        with the same structure as a valid interpolated grid.

        The structure is obtained by interpolating at the mid-point of the
        limits of every grid parameter; the values are then replaced by NaNs.

        Notes
        -----
        Authors: Allan Denis
        '''
        # Evaluate the limits once instead of twice per parameter.
        lims = self.lims_params_grid
        ref_params = {k: 0.5 * (lims[k][0] + lims[k][1]) for k in self.keys}
        # Passing the parameters through `coords` keeps them separate from the
        # `method` / `kwargs` arguments of Dataset.interp.
        grid_1d = self._grid.interp(coords=ref_params, method="nearest", kwargs={"fill_value": np.nan})
        # Replace values by nans
        return grid_1d * np.nan

    def _interpolate_between_gridpoints(self, grid_params: dict[str, float], method: str = "linear", print_logger: bool = False):
        '''
        Interpolate between gridpoints in the adapted spectroscopic and photometric grids.

        Parameters
        ----------
        grid_params : dict[str, float]
            Dictionary of grid parameter values. Its keys must be exactly
            the grid parameter keys.
        method : str
            Interpolation method
        print_logger : bool
            Whether to print logger info

        Returns
        -------
        xr.Dataset
            Interpolated 1D grid. If parameters are out-of-bounds, returns a 1D grid filled with NaNs.

        Raises
        ------
        ForMoSAError
            If ``grid_params`` is not a dictionary, if its keys do not match
            the grid parameters, or if the grid is not an ``xr.Dataset``.

        Notes
        -----
        Authors: Simon Petrus, Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        if print_logger:
            self.logger.info('Interpolate between gridpoints in the grid')

        if not isinstance(grid_params, dict):
            raise ForMoSAError(f'<Wrong type for grid_params: {type(grid_params)}. Expected a dictionary>', self.logger)

        # ==================================================
        # Keys validation
        # ==================================================
        expected_keys = set(self.keys)
        provided_keys = set(grid_params.keys())
        if provided_keys != expected_keys:
            missing = expected_keys - provided_keys
            extra = provided_keys - expected_keys
            raise ForMoSAError(f"Grid parameter mismatch. Missing: {missing}, Extra: {extra}", self.logger)

        # ==================================================
        # Bounds checks
        # ==================================================
        # Evaluate the limits once rather than once per parameter.
        lims = self.lims_params_grid
        out_of_bounds = False
        for name, value in grid_params.items():
            min_val, max_val = lims[name]
            # Written as a negated chained comparison so that a NaN value is
            # also treated as out-of-bounds.
            if not (min_val <= value <= max_val):
                out_of_bounds = True
                self.logger.warning(f"Grid parameter '{name}'={value} outside bounds [{min_val}, {max_val}]. Returning NaN grid")

        # ==================================================
        # out-of-bounds
        # ==================================================
        if out_of_bounds:
            return self._nan_interpolated_grid()

        # ==================================================
        # Normal interpolation
        # ==================================================
        if not isinstance(self.grid, xr.Dataset):
            raise ForMoSAError("Grid is not a valid xarray.Dataset", self.logger)

        if print_logger:
            self.logger.debug('<Interpolation>')

        # Passing the parameters through `coords` keeps them separate from the
        # `method` / `kwargs` arguments of Dataset.interp.
        return self.grid.interp(coords=dict(grid_params), method=method, kwargs={"fill_value": np.nan})

    def save_grid(self, store_path: str | os.PathLike) -> None:
        '''
        Save the grid to a specified directory.

        The file is written as ``<suffix>_<grid_name>.nc``; the directory is
        created when it does not exist.

        Parameters
        ----------
        store_path : str | os.PathLike
            Path where to store the grid

        Raises
        ------
        ForMoSAError
            If ``store_path`` is neither a string nor an ``os.PathLike``.

        Notes
        -----
        Authors: Simon Petrus, Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        if not isinstance(store_path, (str, os.PathLike)):
            raise ForMoSAError(f'<Wrong type for store_path: {type(store_path)}. Expected a string or os.PathLike>', self.logger)

        store_path = Path(store_path).expanduser()
        if not store_path.exists():
            self.logger.warning(f'{store_path} does not exist. Creating it')
            store_path.mkdir(exist_ok=True, parents=True)

        filename = f"{self.suffix}_{self.grid_name}.nc"
        self.logger.info(f" Saving Grid {filename} to {store_path}")
        self.grid.to_netcdf(store_path / filename, format="NETCDF4", engine="netcdf4", mode="w")

    def _load_grid(self, store_path: str | os.PathLike, grid_name: str = 'in-memory-grid', suffix: str = 'native') -> xr.Dataset:
        '''
        Method to load a grid from a nc file.

        The file read is ``<store_path>/<suffix>_<grid_name>.nc``. The loaded
        grid is returned; the instance itself is not modified.

        Parameters
        ----------
        store_path : str | os.PathLike
            Directory containing the grid file
        grid_name : str
            Name of the grid
        suffix : str
            Suffix of the grid

        Returns
        -------
        xr.Dataset
            Loaded grid

        Raises
        ------
        ForMoSAError
            If ``store_path`` has a wrong type or if the loader fails.

        Notes
        -----
        Authors: Simon Petrus, Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        if not isinstance(store_path, (str, os.PathLike)):
            raise ForMoSAError(f'<Wrong type for store_path: {type(store_path)}. Expected a string or os.PathLike>', self.logger)

        store_path = Path(store_path).expanduser()
        self.logger.info(f'<Load adapted grid {suffix}_{grid_name} from the store_path {store_path}>')

        grid_file = store_path / f"{suffix}_{grid_name}.nc"
        self.logger.debug(f'<Open grid file {grid_file}>')
        try:
            return GridLoader._from_file(grid_file)
        except ForMoSAError as e:
            # Re-raise with this instance's logger, keeping the original cause.
            raise ForMoSAError(e, self.logger) from e

    def _restricted_grid(self, windows: str | None = None, print_logger: bool = True, extension: float = 0.0) -> "ModelGrid":
        '''
        Returns a version of the grid restricted to the given wavelength range.

        The instance itself is left untouched: the restriction is applied
        to a deep copy.

        Parameters
        ----------
        windows : str
            Windows in the format 'wmin1,wmax1 / wmin2,wmax2 / ...', compared
            against ``wave`` (display unit). Defaults to the full range.
        print_logger : bool
            Whether to print the Logger
        extension : float
            Extension factor of the windows: each window becomes
            ``[wmin * (1 - extension), wmax * (1 + extension)]``

        Returns
        -------
        ModelGrid
            An instance of class ModelGrid

        Raises
        ------
        ForMoSAError
            If ``windows`` is not a string or if the restricted grid does not
            pass validation.

        Notes
        -----
        Authors: Allan Denis
        '''
        # self.wave performs a unit conversion on each access: evaluate it once.
        wave = self.wave

        if windows is None:
            windows = f'{wave[0]}, {wave[-1]}'

        if not isinstance(windows, str):
            raise ForMoSAError(f'<Wrong type for windows: {type(windows)}. Requires a string>', self.logger)

        if print_logger:
            self.logger.debug(f'Restrict grid {self.grid_name} onto windows {windows}')

        # Union of all the windows, as a boolean mask over the wavelength axis.
        selected = np.zeros(len(wave), dtype=bool)
        for window in windows.split("/"):
            wmin, wmax = map(float, window.split(","))
            selected |= (wave >= wmin * (1 - extension)) & (wave <= wmax * (1 + extension))
        # Sorted, unique indices of the selected wavelengths.
        ind = np.flatnonzero(selected)

        restricted_grid = copy.deepcopy(self)

        # New restricted grid
        restricted_grid._grid = restricted_grid.grid.isel(wavelength=ind)
        # Keep the per-wavelength resolution aligned with the new wavelength
        # axis (public `attrs` rather than xarray's private `_attrs`).
        restricted_grid.grid.attrs['res'] = restricted_grid.res[ind]

        # Final validation
        try:
            GridLoader._validate_model_grid_dataset(restricted_grid.grid)
        except ForMoSAError as e:
            raise ForMoSAError(e, self.logger) from e

        if print_logger:
            self.logger.info(f' Generated restricted Grid. Former grid length: {len(wave)}. New grid length: {len(restricted_grid.wave)}')

        return restricted_grid