# Source code for ForMoSA.grid.subgrid_base

import os
import logging
import traceback
import numpy as np
import xarray as xr
from tqdm import tqdm
import astropy.units as u
from abc import ABC, abstractmethod
from joblib import Parallel, delayed
from tqdm.contrib.logging import logging_redirect_tqdm

from ForMoSA.core.errors import ForMoSAError
from ForMoSA.grid.model_grid import ModelGrid
from ForMoSA.core.loggings import setup_logging
from ForMoSA.grid.grid_loader import GridLoader
from ForMoSA.parameter.parameter import Parameter
from ForMoSA.observation.observation_base import Observation
from ForMoSA.transform.observed import ObservedModel, ObservedParameters
from ForMoSA.core.enums import ObservationType, WavelengthUnit, ParameterKind, LogLikelihoodType


def _adapt_worker(subgrid, restricted_grid, idx):
    """
    Load one model from the restricted grid and adapt it (joblib worker).

    Parameters
    ----------
    subgrid
        Object exposing ``_adapt_model(model)``; used to adapt the loaded model
    restricted_grid
        Object exposing ``_load_model_at_specific_index(idx)``
    idx
        Index of the model to load; returned unchanged so that the caller can
        match results that arrive out of order

    Returns
    -------
    tuple
        ``(idx, adapted_model, extraction_ok)`` where ``extraction_ok`` is
        False when the loaded model contains at least one NaN
    """
    import logging

    # Suppress every record of level WARNING and below while the worker runs:
    # in spawned worker processes those records would be written straight to
    # stderr and interleave with the tqdm bar of the main process.
    logging.disable(logging.WARNING)
    try:
        # Restrict BLAS to a single thread when threadpoolctl is installed;
        # otherwise fall back to a context manager that does nothing.
        try:
            from threadpoolctl import threadpool_limits
            blas_guard = threadpool_limits(limits=1, user_api='blas')
        except ImportError:
            from contextlib import nullcontext
            blas_guard = nullcontext()

        with blas_guard:
            raw_model = restricted_grid._load_model_at_specific_index(idx)
            contains_nan = np.isnan(raw_model.values).any()
            adapted = subgrid._adapt_model(raw_model)
    finally:
        # Always re-enable logging, even if loading or adapting raised
        logging.disable(logging.NOTSET)

    return idx, adapted, not contains_nan


class SubGrid(ModelGrid, ABC):
    '''
    Base class for any subgrid (spectroscopic or photometric).
    Inherits from the ModelGrid class.

    Parameters
    ----------
    grid : xr.Dataset
        Grid
    parent_grid : ModelGrid
        Parent model grid
    logger : logging.Logger | None
        Logger
    log_level : str
        Level of the Logger
    name : str
        Name of the subgrid
    display_unit : WavelengthUnit
        Unit of the wavelength

    Notes
    -----
    Authors: Allan Denis
    '''

    def __init__(self,
                 grid: xr.Dataset,
                 parent_grid: ModelGrid,
                 logger: logging.Logger | None = None,
                 log_level: str = "INFO",
                 name: str = 'Unknown',
                 display_unit: WavelengthUnit = WavelengthUnit.MICROMETER):
        # Initialization from the ModelGrid
        super().__init__(grid, logger=logger, log_level=log_level, display_unit=display_unit)

        self._name = name
        self._parent_grid = parent_grid
        # The name is also stored in the attributes of the dataset itself
        self._grid.attrs['name'] = name

        self.logger.info(f' Setting wavelength unit of subgrid {self.name} to {self.unit}>')

    # =========================================
    # Abstract methods
    # (Subclasses must implement these methods)
    # =========================================
    @property
    @abstractmethod
    def GridType(self) -> ObservationType.obstype:
        """Grid type (spectroscopic or photometric)."""
        pass

    @property
    @abstractmethod
    def relevant_parameter_kinds(self) -> list[ParameterKind]:
        """List of relevant parameter kinds the subgrid applies to."""
        # Default list, reachable from an override through super()
        return [
            ParameterKind.ALPHA,
            ParameterKind.RADIUS,
            ParameterKind.DISTANCE,
            ParameterKind.AV,
            ParameterKind.BB_T,
            ParameterKind.BB_R
        ]

    @abstractmethod
    def _adapt_model(self, model: xr.DataArray) -> xr.DataArray:
        '''
        Adapt a single model to the target wavelength and resolution.
        Optionally remove the continuum.

        Parameters
        ----------
        model : xr.DataArray
            Model to adapt

        Returns
        -------
        xr.DataArray
            Adapted model

        Notes
        -----
        Authors: Allan Denis
        '''
        pass

    @abstractmethod
    def _get_restriction_bounds(self) -> tuple[float, float]:
        '''
        Get restriction bounds for computing the restricted subgrid before
        the adaptation of the subgrid.

        Returns
        -------
        tuple[float, float]
            Minimum and maximum wavelengths of the restricted subgrid

        Notes
        -----
        Authors: Allan Denis
        '''
        pass

    # Declared as a static method: the signature has no ``cls`` argument, so a
    # classmethod would bind the class itself to ``model``.
    @staticmethod
    @abstractmethod
    def _apply_physics_effects(model: ObservedModel, params: dict[Parameter, float]) -> ObservedModel:
        '''
        Apply the relevant physics effects for the given model type.

        Parameters
        ----------
        model : ObservedModel
            Instance of class ObservedModel to transform
        params : dict[Parameter, float]
            Dictionary of parameters values.

        Returns
        -------
        ObservedModel
            Instance of class ObservedModel transformed

        Notes
        -----
        Authors: Allan Denis
        '''
        pass

    # Declared as a static method for the same reason as _apply_physics_effects
    @staticmethod
    @abstractmethod
    def _apply_observational_effects(model: ObservedModel,
                                     obs: Observation,
                                     bounds: tuple[float, float] | None = None) -> ObservedModel:
        '''
        Apply the relevant observational effects for the given model type.

        Parameters
        ----------
        model : ObservedModel
            Instance of class ObservedModel to transform
        obs : Observation
            Instance of class Observation
        bounds : tuple[float, float]
            Bounds for the least squares

        Returns
        -------
        ObservedModel
            Instance of class ObservedModel transformed

        Notes
        -----
        NOTE(review): ``_build_model_from_params`` passes the bounds with the
        keyword ``bounds_lsq``, so concrete implementations have to accept
        that keyword name.

        Authors: Allan Denis
        '''
        pass

    @staticmethod
    @abstractmethod
    def _compute_loglike(model: ObservedModel, obs: Observation, logL_type: LogLikelihoodType) -> float:
        '''
        Compute the loglikelihood given a transformed model, an observation
        and a loglikelihood function

        Parameters
        ----------
        model : ObservedModel
            Instance of class ObservedModel
        obs : Observation
            Observation
        logL_type : LogLikelihoodType
            Loglikelihood function

        Returns
        -------
        float
            logL value

        Notes
        -----
        Authors: Simon Petrus, Matthieu Ravet and Allan Denis
        '''
        pass

    # =========================================
    # Common properties
    # =========================================
    @property
    def parent_grid(self) -> ModelGrid:
        """Parent grid."""
        return self._parent_grid

    @property
    def suffix(self) -> str:
        """Suffix used for saving."""
        return 'adapted'

    @property
    def name(self) -> str:
        """Name of the subgrid."""
        return self._name

    @property
    def grid_name(self) -> str:
        """Name of the grid."""
        return f'{self.parent_grid.grid_name}_{self.name}_{self.GridType}'

    @property
    def wavelength_range(self) -> tuple[float, float]:
        """Wavelength range of the subgrid."""
        return float(self.wave.min()), float(self.wave.max())

    @property
    def unit(self) -> u.core.PrefixUnit:
        """Unit of the wavelength to display."""
        return self._display_unit.unit

    @property
    def is_spectroscopic(self) -> bool:
        """Whether subgrid is spectroscopic."""
        return self.GridType == ObservationType.SPECTROSCOPIC.obstype

    @property
    def is_photometric(self) -> bool:
        """Whether subgrid is photometric."""
        return self.GridType == ObservationType.PHOTOMETRIC.obstype

    # ==========================
    # Class methods
    # ==========================
    @classmethod
    def from_dataset(cls,
                     ds: xr.Dataset,
                     parent_grid: ModelGrid,
                     logger: logging.Logger | None = None,
                     log_level: str = 'INFO',
                     display_unit: WavelengthUnit = WavelengthUnit.MICROMETER) -> 'SubGrid':
        '''
        Generate SubGrid from dataset.

        The concrete class is selected from the ``grid_type`` attribute of
        the dataset.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset containing the parameters of the subgrid
        parent_grid : ModelGrid
            Parent model grid
        logger : logging.Logger
            Logger
        log_level : str
            Level of the Logger (only used when no logger is given)
        display_unit : WavelengthUnit
            Unit to display for wavelength

        Returns
        -------
        'SubGrid'
            An instance of class SubGrid

        Raises
        ------
        ForMoSAError
            If the dataset has no ``grid_type`` attribute or if its value is
            not one of the ObservationType values

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(level=log_level, name='SubGrid')
        logger.debug('Extracting SubGrid from dataset')

        # Imported locally, presumably to avoid a circular import with the
        # subclass modules (TODO confirm)
        from ForMoSA.grid.subgrid_spectroscopy import SubGridSpectroscopy
        from ForMoSA.grid.subgrid_photometry import SubGridPhotometry

        grid_type = ds.attrs.get("grid_type")
        if grid_type is None:
            raise ForMoSAError("Dataset has no 'grid_type' attribute", logger)

        if grid_type == ObservationType.SPECTROSCOPIC.value:
            return SubGridSpectroscopy.from_grid(ds, parent_grid=parent_grid, logger=logger, display_unit=display_unit)
        if grid_type == ObservationType.PHOTOMETRIC.value:
            return SubGridPhotometry.from_grid(ds, parent_grid=parent_grid, logger=logger, display_unit=display_unit)

        expected = [obs_type.value for obs_type in ObservationType]
        raise ForMoSAError(f"Unrecognized grid_type: {grid_type}. Expected {expected}", logger)

    @classmethod
    def from_file(cls,
                  path: str | os.PathLike,
                  parent_grid: ModelGrid,
                  logger: logging.Logger | None = None,
                  log_level: str = 'INFO',
                  display_unit: WavelengthUnit = WavelengthUnit.MICROMETER) -> 'SubGrid':
        '''
        Generate SubGrid from file.

        Parameters
        ----------
        path : str | os.PathLike
            Path to the .nc file
        parent_grid : ModelGrid
            Parent model grid
        logger : logging.Logger
            Logger
        log_level : str
            Level of the Logger (only used when no logger is given)
        display_unit : WavelengthUnit
            Unit to display for wavelength

        Returns
        -------
        'SubGrid'
            An instance of class SubGrid

        Raises
        ------
        ForMoSAError
            If the file cannot be loaded, or for the reasons listed in
            ``from_dataset``

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(level=log_level, name='SubGrid')
        logger.debug(f'Loading grid from file {path}')

        try:
            ds = GridLoader._from_file(path)
        except ForMoSAError as e:
            raise ForMoSAError(e, logger) from e

        logger.info(f' {ds.attrs.get("grid_type")} grid detected')

        return cls.from_dataset(ds=ds, parent_grid=parent_grid, logger=logger, display_unit=display_unit)

    # ==========================
    # Methods
    # ==========================
    def _build_empty_adapted_grid(self, target_wavelength: np.ndarray, target_resolution: np.ndarray) -> xr.Dataset:
        '''
        Build an empty (NaN-filled) grid from the parent grid and the target
        wavelength and resolution

        Parameters
        ----------
        target_wavelength : np.ndarray
            Wavelength to reach for the subgrid
        target_resolution : np.ndarray
            Resolution to reach for the subgrid

        Returns
        -------
        xr.Dataset
            Empty grid

        Raises
        ------
        ForMoSAError
            If the target wavelength is not fully contained in the wavelength
            range of the parent grid, or if the grid cannot be created

        Notes
        -----
        Authors: Allan Denis
        '''
        self._logger.debug(f'Building empty grid from the native grid {self.parent_grid.grid_name}')

        target_wavelength = np.atleast_1d(target_wavelength).astype(float)
        target_resolution = np.atleast_1d(target_resolution).astype(float)

        # The target wavelength must be fully covered by the parent grid
        parent_wavelength = self.parent_grid.grid["wavelength"].values
        wl_min_parent = np.nanmin(parent_wavelength)
        wl_max_parent = np.nanmax(parent_wavelength)
        wl_min_target = np.nanmin(target_wavelength)
        wl_max_target = np.nanmax(target_wavelength)

        if (wl_min_target < wl_min_parent) or (wl_max_target > wl_max_parent):
            raise ForMoSAError(f"target_wavelength={target_wavelength} is outside the parent grid [{wl_min_parent}, {wl_max_parent}]>", self.logger)

        # Shape of the adapted grid: target wavelength axis first, followed by
        # the parameter axes of the native grid
        data_shape = len(target_wavelength), *tuple(len(self.parent_grid.grid[key]) for key in self.parent_grid.keys)

        # Generate empty grid with the attributes and coordinates of the native grid
        empty_data = np.full(data_shape, np.nan)
        coords = {"wavelength": target_wavelength}
        coords.update({key: self.parent_grid.grid[key].values for key in self.parent_grid.keys})
        attrs = self.parent_grid.attrs.copy()
        attrs["res"] = target_resolution

        try:
            # Create grid and check consistency between the parameters of the grid
            empty_grid = GridLoader._from_data(empty_data, coords, attrs)
        except ForMoSAError as e:
            raise ForMoSAError(e, self.logger) from e

        # Logged before returning (this statement used to sit after the
        # return and could never run)
        self.logger.info(f" Initialized empty subgrid with shape {data_shape}")

        return empty_grid

    def _warn_failed_extraction(self, restricted_grid, idx: tuple) -> None:
        '''
        Log a warning listing the grid parameters of a model whose extraction
        produced NaN values.

        Parameters
        ----------
        restricted_grid
            Restricted parent grid the model was loaded from
        idx : tuple
            Index of the model along the parameter axes of the grid
        '''
        msg = 'Extraction of model failed : ' + ''.join(
            f'{title}={restricted_grid.key_values[key][idx[i]]}, '
            for i, (key, title) in enumerate(zip(restricted_grid.keys, restricted_grid.titles))
        )
        self._logger.warning(msg)

    def adapt_grid(self, backend: str = 'loky', n_jobs: int = -1) -> None:
        '''
        Adapt the entire grid to the observation.

        Each model of the restricted parent grid is loaded, adapted with
        ``_adapt_model`` and written in place into the grid of this subgrid.
        The adaptation is first attempted in parallel; if that raises, a
        warning with the traceback is logged and the whole adaptation is
        redone serially. An error raised by the serial pass is logged and
        not re-raised.

        Parameters
        ----------
        backend : str
            Joblib parallel backend. Built-in options: 'loky' (default),
            'multiprocessing', 'threading', 'sequential'.
            Third-party: 'dask', 'ray'.
        n_jobs : int
            Number of parallel jobs. -1 uses all available CPUs.
            Passed to joblib.Parallel.

        Notes
        -----
        Authors: Arthur Vigan and Allan Denis
        '''
        # Get a restricted version of the parent grid to speed up the adaptation
        wmin, wmax = self._get_restriction_bounds()
        restricted_grid = self.parent_grid._restricted_grid(f'{wmin}, {wmax}')

        # Build indices: one per model, i.e. every axis except the first one
        shape = self.grid.grid.values.shape[1:]
        indices = list(np.ndindex(shape))
        total = len(indices)

        self._logger.debug(f'Adapting grid {self.grid_name}')
        self._logger.info(f" Parallel adaptation (backend='{backend}', n_jobs={n_jobs})")

        try:
            with logging_redirect_tqdm(loggers=[logging.getLogger('ForMoSA')]):
                with tqdm(total=total, leave=True, desc=self.grid_name, unit='model') as pbar:
                    # Results come back in completion order, so each one
                    # carries its own index
                    for idx, result, extraction_ok in Parallel(n_jobs=n_jobs, backend=backend, return_as='generator_unordered')(
                        delayed(_adapt_worker)(self, restricted_grid, idx) for idx in indices
                    ):
                        if not extraction_ok:
                            self._warn_failed_extraction(restricted_grid, idx)

                        self._grid.grid[(...,) + idx] = result
                        pbar.update(1)
        except Exception:
            self._logger.warning("Parallel adaptation failed, fallback to serial\n" + traceback.format_exc())

            try:
                with logging_redirect_tqdm(loggers=[logging.getLogger('ForMoSA')]):
                    for idx in tqdm(indices, leave=True, desc=self.grid_name, unit='model'):
                        model = restricted_grid._load_model_at_specific_index(idx)
                        # Same NaN check as the one done by the parallel workers
                        if np.any(np.isnan(model.values)):
                            self._warn_failed_extraction(restricted_grid, idx)

                        result = self._adapt_model(model)
                        self._grid.grid[(...,) + idx] = result
            except Exception as e:
                # Best effort: the failure is reported but not propagated
                self._logger.error(f"Non parallel adaptation produced the following error: {e}")

    def _compute_loglike_for_obs(self,
                                 observed_model: ObservedModel,
                                 observation: "Observation",
                                 logL_type: LogLikelihoodType = LogLikelihoodType.CHI2,
                                 interp_method: str = 'linear',
                                 bounds_lsq: tuple[float, float] = (-np.inf, np.inf)) -> float:
        '''
        Evaluate the loglikelihood of an observed model against an observation

        Parameters
        ----------
        observed_model : ObservedModel
            Instance of class ObservedModel
        observation : Observation
            Observation
        logL_type : LogLikelihoodType
            Loglikelihood function
        interp_method : str
            Interpolation method (not used by this method)
        bounds_lsq : tuple[float, float]
            (lower, higher) bounds for the Least Squares (not used by this method)

        Returns
        -------
        float
            Loglikelihood, or -inf when the flux of the model contains NaN

        Raises
        ------
        ForMoSAError
            If observed_model is not an ObservedModel

        Notes
        -----
        Authors: Allan Denis
        '''
        # Initial check
        if not isinstance(observed_model, ObservedModel):
            raise ForMoSAError(f'Wrong type for observed_model: {type(observed_model)}. Expected an ObservedModel', self.logger)

        if np.any(np.isnan(observed_model.flux)):
            # NaN are produced by out-of-grid parameters. The loglike must be -inf in this case
            return -float('inf')

        # Compute loglikelihood
        return self._compute_loglike(observed_model, observation, logL_type=logL_type)

    def _build_model_from_params(self,
                                 observed_params: ObservedParameters,
                                 observation: Observation,
                                 interp_method: str = 'linear',
                                 bounds_lsq: tuple[float, float] = (-np.inf, np.inf)) -> ObservedModel:
        '''
        Build a model from a set of parameters and their associated values

        Parameters
        ----------
        observed_params : ObservedParameters
            Instance of class ObservedParameters
        observation : Observation
            Observation
        interp_method : str
            Interpolation method
        bounds_lsq : tuple[float, float]
            Bounds for the least squares

        Returns
        -------
        ObservedModel
            Model built with the values of the parameters

        Raises
        ------
        ForMoSAError
            If applying the physics or the observational effects fails

        Notes
        -----
        Authors: Allan Denis
        '''
        # Split parameters
        grid_params = observed_params.grid
        physics_params = observed_params.physics

        # Evaluation of the model at grid points
        observed_model = self.evaluate_at_gridpoints(grid_params, interp_method=interp_method)

        if np.any(np.isnan(observed_model.flux)):
            # NaNs produced because grid parameter is outside of the grid bounds,
            # so we directly return the model
            return observed_model

        # Apply physics transformations
        try:
            observed_model = self._apply_physics_effects(observed_model, physics_params)
        except ForMoSAError as e:
            raise ForMoSAError(e, self.logger) from e

        # Apply observational transformations
        try:
            observed_model = self._apply_observational_effects(observed_model, observation, bounds_lsq=bounds_lsq)
        except ForMoSAError as e:
            raise ForMoSAError(e, self.logger) from e

        return observed_model

    def evaluate_at_gridpoints(self, grid_params: ObservedParameters, interp_method: str = 'linear') -> ObservedModel:
        '''
        Evaluate model given a set of grid parameters and their associated values.

        Parameters
        ----------
        grid_params : ObservedParameters
            Grid parameters and their associated values. Must contain grid
            parameters and no physics parameters
        interp_method : str
            Interpolation method

        Returns
        -------
        observed_model : ObservedModel
            Instance of class ObservedModel

        Raises
        ------
        ForMoSAError
            If grid_params is not an ObservedParameters, has no grid
            parameters, or has physics parameters

        Notes
        -----
        Authors: Allan Denis
        '''
        # Initial checks
        if not isinstance(grid_params, ObservedParameters):
            raise ForMoSAError(f'<Wrong type for grid_params: {type(grid_params)}. Expected an ObservedParameters>', self.logger)

        if (not grid_params.has_grid) or (grid_params.has_physics):
            raise ForMoSAError("You should use only grid parameters in grid_params", self.logger)

        grid_values = grid_params.values_by_name

        # Interpolation
        model = self._interpolate_between_gridpoints(grid_values, method=interp_method)

        observed_model = ObservedModel(model.coords['wavelength'], model.grid.values, model.attrs['res'])

        return observed_model