Source code for ForMoSA.grid.subgrid_base

import os
import logging
import numpy as np
import xarray as xr
from tqdm import tqdm
import astropy.units as u
import multiprocessing as mp
from functools import partial
from abc import ABC, abstractmethod
from multiprocessing.pool import ThreadPool

from ForMoSA.core.errors import ForMoSAError
from ForMoSA.grid.model_grid import ModelGrid
from ForMoSA.core.loggings import setup_logging
from ForMoSA.grid.grid_loader import GridLoader
from ForMoSA.parameter.parameter import Parameter
from ForMoSA.observation.observation_base import Observation
from ForMoSA.transform.observed import ObservedModel, ObservedParameters
from ForMoSA.core.enums import ObservationType, WavelengthUnit, ParameterKind, LogLikelihoodType


class SubGrid(ModelGrid, ABC):
    '''
    Base class for any subgrid (spectroscopic or photometric).

    A subgrid is built from a parent ModelGrid whose models are adapted to a
    target wavelength sampling and resolution. Inherits from the ModelGrid class.

    Parameters
    ----------
    grid : xr.Dataset
        Grid
    parent_grid : ModelGrid
        Parent model grid
    logger : logging.Logger | None
        Logger
    log_level : str
        Level of the Logger
    name : str
        Name of the subgrid
    display_unit : WavelengthUnit
        Unit of the wavelength

    Notes
    -----
    Authors: Allan Denis
    '''

    def __init__(self, grid: xr.Dataset, parent_grid: ModelGrid, logger: logging.Logger | None = None,
                 log_level: str = "INFO", name: str = 'Unknown',
                 display_unit: WavelengthUnit = WavelengthUnit.MICROMETER):
        # Initialization from the ModelGrid
        super().__init__(grid, logger=logger, log_level=log_level, display_unit=display_unit)

        self._name = name
        self._parent_grid = parent_grid
        # Also record the name in the dataset attributes
        self._grid.attrs['name'] = name

        self.logger.info(f' Setting wavelength unit of subgrid {self.name} to {self.unit}>')

    # =========================================
    # Abstract methods
    # (Subclasses must implement these methods)
    # =========================================

    @property
    @abstractmethod
    def GridType(self) -> ObservationType.obstype:
        """Grid type (spectroscopic or photometric)."""
        pass

    @property
    @abstractmethod
    def relevant_parameter_kinds(self) -> list[ParameterKind]:
        """
        List of relevant parameter kinds the subgrid applies to.

        Subclasses must override this property; the default list returned here
        can be reused through ``super()``.
        """
        return [
            ParameterKind.ALPHA,
            ParameterKind.RADIUS,
            ParameterKind.DISTANCE,
            ParameterKind.AV,
            ParameterKind.BB_T,
            ParameterKind.BB_R,
        ]

    @abstractmethod
    def _adapt_model(self, model: xr.DataArray) -> xr.DataArray:
        '''
        Adapt a single model to the target wavelength and resolution.
        Optionally remove the continuum.

        Parameters
        ----------
        model : xr.DataArray
            Model to adapt

        Returns
        -------
        xr.DataArray
            Adapted model

        Notes
        -----
        Authors: Allan Denis
        '''
        pass

    @abstractmethod
    def _get_restriction_bounds(self) -> tuple[float, float]:
        '''
        Get the bounds used to restrict the parent grid before the adaptation of the subgrid.

        Returns
        -------
        tuple[float, float]
            Minimum and maximum wavelengths of the restricted subgrid

        Notes
        -----
        Authors: Allan Denis
        '''
        pass

    @classmethod
    @abstractmethod
    def _apply_physics_effects(cls, model: ObservedModel, params: dict[Parameter, float]) -> ObservedModel:
        '''
        Apply the relevant physics effects for the given model type.

        Parameters
        ----------
        model : ObservedModel
            Instance of class ObservedModel to transform
        params : dict[Parameter, float]
            Dictionary of parameters values.

        Returns
        -------
        ObservedModel
            Instance of class ObservedModel transformed

        Notes
        -----
        Authors: Allan Denis
        '''
        pass

    @classmethod
    @abstractmethod
    def _apply_observational_effects(cls, model: ObservedModel, obs: Observation,
                                     bounds_lsq: tuple[float, float] | None = None) -> ObservedModel:
        '''
        Apply the relevant observational effects for the given model type.

        Parameters
        ----------
        model : ObservedModel
            Instance of class ObservedModel to transform
        obs : Observation
            Instance of class Observation
        bounds_lsq : tuple[float, float] | None
            Bounds for the least squares (passed by keyword by ``_build_model_from_params``)

        Returns
        -------
        ObservedModel
            Instance of class ObservedModel transformed

        Notes
        -----
        Authors: Allan Denis
        '''
        pass

    @staticmethod
    @abstractmethod
    def _compute_loglike(model: ObservedModel, obs: Observation, logL_type: LogLikelihoodType) -> float:
        '''
        Compute the loglikelihood given a transformed model, an observation and a loglikelihood function.

        Parameters
        ----------
        model : ObservedModel
            Instance of class ObservedModel
        obs : Observation
            Observation
        logL_type : LogLikelihoodType
            Loglikelihood function

        Returns
        -------
        float
            logL value

        Notes
        -----
        Authors: Simon Petrus, Matthieu Ravet and Allan Denis
        '''
        pass

    # =========================================
    # Common properties
    # =========================================

    @property
    def parent_grid(self) -> ModelGrid:
        """Parent grid."""
        return self._parent_grid

    @property
    def suffix(self) -> str:
        """Suffix used for saving."""
        return 'adapted'

    @property
    def name(self) -> str:
        """Name of the subgrid."""
        return self._name

    @property
    def grid_name(self) -> str:
        """Name of the grid: '<parent grid name>_<subgrid name>_<grid type>'."""
        return f'{self.parent_grid.grid_name}_{self.name}_{self.GridType}'

    @property
    def wavelength_range(self) -> tuple[float, float]:
        """Wavelength range (min, max) of the subgrid."""
        return float(self.wave.min()), float(self.wave.max())

    @property
    def unit(self) -> u.core.PrefixUnit:
        """Unit of the wavelength to display."""
        return self._display_unit.unit

    @property
    def is_spectroscopic(self) -> bool:
        """Whether subgrid is spectroscopic."""
        return self.GridType == ObservationType.SPECTROSCOPIC.obstype

    @property
    def is_photometric(self) -> bool:
        """Whether subgrid is photometric."""
        return self.GridType == ObservationType.PHOTOMETRIC.obstype

    # ==========================
    # Class methods
    # ==========================

    @classmethod
    def from_dataset(cls, ds: xr.Dataset, parent_grid: ModelGrid, logger: logging.Logger | None = None,
                     log_level: str = 'INFO',
                     display_unit: WavelengthUnit = WavelengthUnit.MICROMETER) -> 'SubGrid':
        '''
        Generate SubGrid from dataset.

        The concrete subclass is selected from the ``grid_type`` attribute of the dataset.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset containing the parameters of the subgrid
        parent_grid : ModelGrid
            Parent model grid
        logger : logging.Logger | None
            Logger (a new one is created if None)
        log_level : str
            Level of the Logger
        display_unit : WavelengthUnit
            Unit to display for wavelength

        Returns
        -------
        'SubGrid'
            An instance of class SubGridSpectroscopy or SubGridPhotometry

        Raises
        ------
        ForMoSAError
            If the dataset has no 'grid_type' attribute or if it is not recognized

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(level=log_level, name='SubGrid')
        logger.debug('Extracting SubGrid from dataset')

        # Local imports, presumably to avoid a circular import with the subclass modules
        from ForMoSA.grid.subgrid_spectroscopy import SubGridSpectroscopy
        from ForMoSA.grid.subgrid_photometry import SubGridPhotometry

        grid_type = ds.attrs.get("grid_type")
        if grid_type is None:
            raise ForMoSAError("Dataset has no 'grid_type' attribute", logger)

        if grid_type == ObservationType.SPECTROSCOPIC.value:
            subgrid_cls = SubGridSpectroscopy
        elif grid_type == ObservationType.PHOTOMETRIC.value:
            subgrid_cls = SubGridPhotometry
        else:
            expected = [obs_type.value for obs_type in ObservationType]
            raise ForMoSAError(f"Unrecognized grid_type: {grid_type}. Expected {expected}", logger)

        return subgrid_cls.from_grid(ds, parent_grid=parent_grid, logger=logger, display_unit=display_unit)

    @classmethod
    def from_file(cls, path: str | os.PathLike, parent_grid: ModelGrid, logger: logging.Logger | None = None,
                  log_level: str = 'INFO',
                  display_unit: WavelengthUnit = WavelengthUnit.MICROMETER) -> 'SubGrid':
        '''
        Generate SubGrid from file.

        Parameters
        ----------
        path : str | os.PathLike
            Path to the .nc file
        parent_grid : ModelGrid
            Parent model grid
        logger : logging.Logger | None
            Logger (a new one is created if None)
        log_level : str
            Level of the Logger
        display_unit : WavelengthUnit
            Unit to display for wavelength

        Returns
        -------
        'SubGrid'
            An instance of class SubGrid

        Raises
        ------
        ForMoSAError
            If the file cannot be loaded or the dataset is not a valid subgrid

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(level=log_level, name='SubGrid')
        logger.debug(f'Loading grid from file {path}')

        try:
            ds = GridLoader._from_file(path)
        except ForMoSAError as e:
            raise ForMoSAError(e, logger) from e

        logger.info(f' {ds.attrs.get("grid_type")} grid detected')

        return cls.from_dataset(ds=ds, parent_grid=parent_grid, logger=logger, display_unit=display_unit)

    # ==========================
    # Methods
    # ==========================

    def _build_empty_adapted_grid(self, target_wavelength: np.ndarray, target_resolution: np.ndarray) -> xr.Dataset:
        '''
        Build an empty (NaN-filled) grid from the parent grid and the target wavelength and resolution.

        Parameters
        ----------
        target_wavelength : np.ndarray
            Wavelength to reach for the subgrid
        target_resolution : np.ndarray
            Resolution to reach for the subgrid

        Returns
        -------
        xr.Dataset
            Empty grid with the parameter coordinates and attributes of the parent grid

        Raises
        ------
        ForMoSAError
            If target_wavelength is empty or outside the wavelength range of the parent grid,
            or if the grid cannot be created

        Notes
        -----
        Authors: Allan Denis
        '''
        self._logger.debug(f'Building empty grid from the native grid {self.parent_grid.grid_name}')

        target_wavelength = np.atleast_1d(target_wavelength).astype(float)
        target_resolution = np.atleast_1d(target_resolution).astype(float)
        if target_wavelength.size == 0:
            raise ForMoSAError('target_wavelength is empty', self.logger)

        # The target wavelengths must be covered by the parent grid
        parent_wavelength = self.parent_grid.grid["wavelength"].values
        wl_min_parent = np.nanmin(parent_wavelength)
        wl_max_parent = np.nanmax(parent_wavelength)
        wl_min_target = np.nanmin(target_wavelength)
        wl_max_target = np.nanmax(target_wavelength)
        if (wl_min_target < wl_min_parent) or (wl_max_target > wl_max_parent):
            raise ForMoSAError(f"target_wavelength={target_wavelength} is outside the parent grid [{wl_min_parent}, {wl_max_parent}]>", self.logger)

        # Shape: target wavelength axis followed by the parameter axes of the parent grid
        data_shape = (len(target_wavelength),
                      *(len(self.parent_grid.grid[key]) for key in self.parent_grid.keys))

        # Generate empty grid with the attributes and coordinates of the native grid
        empty_data = np.full(data_shape, np.nan)
        coords = {"wavelength": target_wavelength}
        coords.update({key: self.parent_grid.grid[key].values for key in self.parent_grid.keys})
        attrs = self.parent_grid.attrs.copy()
        attrs["res"] = target_resolution

        try:
            # Create grid and check consistency between the parameters of the grid
            empty_grid = GridLoader._from_data(empty_data, coords, attrs)
        except ForMoSAError as e:
            raise ForMoSAError(e, self.logger) from e

        self.logger.info(f" Initialized empty subgrid with shape {data_shape}")
        return empty_grid

    def adapt_grid(self) -> None:
        '''
        Adapt the entire grid to the observation.

        Each model of the parent grid (restricted to the bounds given by
        ``_get_restriction_bounds``) is passed through ``_adapt_model`` and
        written into the subgrid. A thread pool is tried first; if anything
        fails (including an exception raised inside a worker), every model is
        recomputed serially.

        Raises
        ------
        ForMoSAError
            If the serial fallback also fails.

        Notes
        -----
        Authors: Arthur Vigan and Allan Denis
        '''
        # Get a restricted version of the parent grid to speed up the adaptation
        wmin, wmax = self._get_restriction_bounds()
        restricted_grid = self.parent_grid._restricted_grid(f'{wmin}, {wmax}')

        # Parameter-space shape (axis 0 of the grid is the wavelength)
        shape = self.grid.grid.shape[1:]
        n_models = int(np.prod(shape))

        self._logger.debug(f'Adapting grid {self.grid_name}')
        self._logger.info(" Parallel adaptation")

        try:
            self._adapt_grid_parallel(restricted_grid, shape, n_models)
        except Exception as e:
            self._logger.warning(f"<Parallel adaptation failed: {e}. Falling back to serial mode>")
            try:
                self._adapt_grid_serial(restricted_grid, shape, n_models)
            except Exception as serial_error:
                raise ForMoSAError(f'<Non parallel adaptation produced the following error: {serial_error}>', self.logger) from serial_error

    def _adapt_grid_parallel(self, restricted_grid, shape: tuple[int, ...], n_models: int) -> None:
        '''
        Adapt every model of ``restricted_grid`` with a thread pool.

        Models are loaded in the calling thread and adapted in the pool; results
        are written from the pool's result callback. Exceptions raised by
        ``_adapt_model`` or while storing a result are collected and the first
        one is re-raised once the pool has finished, so failures are no longer
        silently turned into NaN entries.
        '''
        errors: list[BaseException] = []

        with tqdm(total=n_models, leave=False) as pbar:

            def store_result(result, idx):
                # Exceptions raised in a pool callback are not propagated to
                # this thread, so record them explicitly.
                try:
                    self._grid.grid[(..., ) + idx] = result
                except Exception as exc:
                    errors.append(exc)
                    return
                pbar.update(1)

            with ThreadPool(processes=mp.cpu_count()) as pool:
                for idx in np.ndindex(shape):
                    model_to_adapt = restricted_grid._load_model_at_specific_index(idx)
                    pool.apply_async(self._adapt_model, args=(model_to_adapt,),
                                     callback=partial(store_result, idx=idx),
                                     error_callback=errors.append)
                pool.close()
                pool.join()

        if errors:
            raise errors[0]

    def _adapt_grid_serial(self, restricted_grid, shape: tuple[int, ...], n_models: int) -> None:
        '''Adapt every model of ``restricted_grid`` one after the other.'''
        for idx in tqdm(np.ndindex(shape), total=n_models):
            model_to_adapt = restricted_grid._load_model_at_specific_index(idx)
            self._grid.grid[(..., ) + idx] = self._adapt_model(model_to_adapt)

    def _compute_loglike_for_obs(self, observed_model: ObservedModel, observation: "Observation",
                                 logL_type: LogLikelihoodType = LogLikelihoodType.CHI2,
                                 interp_method: str = 'linear',
                                 bounds_lsq: tuple[float, float] = (-np.inf, np.inf)) -> float:
        '''
        Evaluate the loglikelihood of an already-built model against an observation.

        Parameters
        ----------
        observed_model : ObservedModel
            Instance of class ObservedModel
        observation : Observation
            Observation
        logL_type : LogLikelihoodType
            Loglikelihood function
        interp_method : str
            Interpolation method (currently unused by this method)
        bounds_lsq : tuple[float, float]
            (lower, higher) bounds for the Least Squares (currently unused by this method)

        Returns
        -------
        float
            Loglikelihood, or -inf if the model flux contains NaNs

        Raises
        ------
        ForMoSAError
            If observed_model is not an ObservedModel

        Notes
        -----
        Authors: Allan Denis
        '''
        # Initial check
        if not isinstance(observed_model, ObservedModel):
            raise ForMoSAError(f'Wrong type for observed_model: {type(observed_model)}. Expected an ObservedModel', self.logger)

        if np.any(np.isnan(observed_model.flux)):
            # NaN are produced by out-of-grid parameters. The loglike must be -inf in this case
            return -float('inf')

        # Compute loglikelihood
        return self._compute_loglike(observed_model, observation, logL_type=logL_type)

    def _build_model_from_params(self, observed_params: ObservedParameters, observation: Observation,
                                 interp_method: str = 'linear',
                                 bounds_lsq: tuple[float, float] = (-np.inf, np.inf)) -> ObservedModel:
        '''
        Build a model from a set of parameters and their associated values.

        The grid parameters are interpolated first, then the physics and the
        observational effects are applied.

        Parameters
        ----------
        observed_params : ObservedParameters
            Instance of class ObservedParameters
        observation : Observation
            Observation
        interp_method : str
            Interpolation method
        bounds_lsq : tuple[float, float]
            Bounds for the least squares

        Returns
        -------
        ObservedModel
            Model built with the values of the parameters (returned untransformed
            if the interpolated flux contains NaNs)

        Raises
        ------
        ForMoSAError
            If the interpolation or one of the transformations fails

        Notes
        -----
        Authors: Allan Denis
        '''
        # Split parameters
        grid_params = observed_params.grid
        physics_params = observed_params.physics

        # Evaluation of the model at grid points
        observed_model = self.evaluate_at_gridpoints(grid_params, interp_method=interp_method)
        if np.any(np.isnan(observed_model.flux)):
            # NaNs produced because grid parameter is outside of the grid bounds, so we directly return the model
            return observed_model

        # Apply physics transformations
        try:
            observed_model = self._apply_physics_effects(observed_model, physics_params)
        except ForMoSAError as e:
            raise ForMoSAError(e, self.logger) from e

        # Apply observational transformations
        try:
            observed_model = self._apply_observational_effects(observed_model, observation, bounds_lsq=bounds_lsq)
        except ForMoSAError as e:
            raise ForMoSAError(e, self.logger) from e

        return observed_model

    def evaluate_at_gridpoints(self, grid_params: ObservedParameters, interp_method: str = 'linear') -> ObservedModel:
        '''
        Evaluate the model for a set of grid parameters by interpolating between grid points.

        Parameters
        ----------
        grid_params : ObservedParameters
            Grid parameters and their associated values (must contain grid
            parameters only, no physics parameters)
        interp_method : str
            Interpolation method

        Returns
        -------
        observed_model : ObservedModel
            Instance of class ObservedModel

        Raises
        ------
        ForMoSAError
            If grid_params is not an ObservedParameters, has no grid parameters,
            or contains physics parameters

        Notes
        -----
        Authors: Allan Denis
        '''
        # Initial checks
        if not isinstance(grid_params, ObservedParameters):
            raise ForMoSAError(f'<Wrong type for grid_params: {type(grid_params)}. Expected an ObservedParameters>', self.logger)
        if (not grid_params.has_grid) or (grid_params.has_physics):
            raise ForMoSAError("You should use only grid parameters in grid_params", self.logger)

        grid_values = grid_params.values_by_name

        # Interpolation
        model = self._interpolate_between_gridpoints(grid_values, method=interp_method)
        return ObservedModel(model.coords['wavelength'], model.grid.values, model.attrs['res'])