Source code for ForMoSA.config.global_config

import os
import ast
import logging
import numpy as np
from pathlib import Path
from configobj import ConfigObj
from typing import Any, List, Union
from scipy.interpolate import interp1d
from dataclasses import dataclass, field, asdict

import ForMoSA.utils.misc as um
import ForMoSA.parameter.prior as Prior
from ForMoSA.core.errors import ForMoSAError
from ForMoSA.grid.model_grid import ModelGrid
from ForMoSA.core.loggings import setup_logging
from ForMoSA.parameter.parameter import Parameter
from ForMoSA.observation.observation_set import ObservationSet
from ForMoSA.core.enums import PriorType, VsiniFunction, ParameterKind, ObservationType

# ----------------------------
# Dataclasses for each section
# ----------------------------

[docs] @dataclass class ConfigPath: observation_path: list[str | os.PathLike] = field(default_factory=list[str]) adapt_store_path: str | os.PathLike = field(default_factory=str) result_path: str | os.PathLike = field(default_factory=str) model_path: str | os.PathLike = field(default_factory=str)
[docs] def __post_init__(self) -> None: ''' Check paths types. Notes ----- Authors: Allan Denis ''' for name in ['adapt_store_path', 'result_path', 'model_path']: value = getattr(self, name) if not isinstance(value, (str, os.PathLike)): raise ForMoSAError(f"{name} must be str or os.PathLike, got {type(value)}") if isinstance(self.observation_path, (str, os.PathLike)): setattr(self, 'observation_path', [self.observation_path]) if not isinstance(self.observation_path, list): raise ForMoSAError(f"Wrong type for observation_path: {type(self.observation_path)}. Expected a list") if not all(isinstance(obs_path, (str | os.PathLike)) for obs_path in self.observation_path): raise ForMoSAError("observation_path must be a list of str or os.PathLike")
[docs] def set_params(self, **kwargs) -> None: ''' Update paths Parameters ---------- **kwargs : dict Keyword arguments to override attributes of the config Notes ----- Authors: Allan Denis ''' for key, value in kwargs.items(): if not hasattr(self, key): raise ForMoSAError(f'Unknown ConfigPath key: {key}') setattr(self, key, value) # Validation of the format self.__post_init__()
_JOBLIB_BACKENDS = ('loky', 'multiprocessing', 'threading', 'sequential', 'dask', 'ray')
@dataclass
class ConfigAdapt:
    '''
    Grid-adaptation section of the configuration.

    List-valued fields hold either one entry (broadcast to every observation)
    or one entry per observation.

    Notes
    -----
    Authors: Allan Denis
    '''
    method: str = "linear"
    emulator: list[str] = field(default_factory=lambda: ["NA"])
    target_res_obs: list[Union[str, float]] = field(default_factory=lambda: ["obs"])
    target_res_mod: list[Union[str, float]] = field(default_factory=lambda: ["obs"])
    wav_cont: list[Union[str, float]] = field(default_factory=lambda: ["NA"])
    res_cont: list[Union[str, float]] = field(default_factory=lambda: ["NA"])
    backend: str = "loky"
    n_jobs: int = 4

    def __post_init__(self) -> None:
        '''
        Check adapt configuration parameters and normalize types.

        Raises
        ------
        ForMoSAError
            If ``method``/``backend`` are not valid strings or ``n_jobs`` is
            neither -1 nor a positive integer.

        Notes
        -----
        Authors: Allan Denis
        '''
        # Check method
        if not isinstance(self.method, str):
            raise ForMoSAError(f" method must be a string, got {type(self.method)}")
        # Check backend
        if not isinstance(self.backend, str):
            raise ForMoSAError(f" backend must be a string, got {type(self.backend)}")
        if self.backend not in _JOBLIB_BACKENDS:
            raise ForMoSAError(f" backend must be one of {_JOBLIB_BACKENDS}, got '{self.backend}'")
        # Coerce n_jobs from INI scalar strings before type checks.
        if isinstance(self.n_jobs, str):
            try:
                self.n_jobs = int(self.n_jobs.strip())
            except ValueError as exc:
                raise ForMoSAError(
                    f" n_jobs must be -1 (all CPUs) or a positive integer, got {self.n_jobs}"
                ) from exc
        elif isinstance(self.n_jobs, float) and self.n_jobs.is_integer():
            self.n_jobs = int(self.n_jobs)
        # Check n_jobs (bool is a subclass of int and is rejected explicitly)
        if isinstance(self.n_jobs, bool) or not isinstance(self.n_jobs, int) or self.n_jobs == 0 or self.n_jobs < -1:
            raise ForMoSAError(f" n_jobs must be -1 (all CPUs) or a positive integer, got {self.n_jobs}")
        # Normalize fields
        self.emulator = um.normalize_list(self.emulator, "emulator")
        self.target_res_obs = um.normalize_list(self.target_res_obs, "target_res_obs", um.to_float_if_possible)
        self.target_res_mod = um.normalize_list(self.target_res_mod, "target_res_mod", um.to_float_if_possible)
        self.wav_cont = um.normalize_list(self.wav_cont, "wav_cont", um.to_float_if_possible)
        self.res_cont = um.normalize_list(self.res_cont, "res_cont", um.to_float_if_possible)

    # =======================
    # Methods
    # =======================
    @staticmethod
    def _as_float(value: Union[str, float], name: str) -> float:
        '''
        Convert a resolution entry to float.

        Parameters
        ----------
        value : str | float
            Entry to convert
        name : str
            Name of the config field, used in the error message

        Returns
        -------
        float
            The converted value

        Raises
        ------
        ForMoSAError
            If ``value`` is neither a number nor a numeric string.
        '''
        try:
            return float(value)
        except (TypeError, ValueError) as exc:
            raise ForMoSAError(f' {name} entries must be a float or a supported keyword, got {value!r}') from exc

    def _check_with_n_obs(self, n_obs: int) -> None:
        '''
        Check consistency between lengths of list parameters and n_obs.

        Single-entry lists are broadcast in place to ``n_obs`` entries.

        Parameters
        ----------
        n_obs : int
            Number to be tested against lengths of list parameters

        Raises
        ------
        ForMoSAError
            If ``n_obs`` is not a positive integer, or a list has more than one
            entry but not exactly ``n_obs``.

        Notes
        -----
        Authors: Allan Denis
        '''
        if (not isinstance(n_obs, int)) or (n_obs < 1):
            raise ForMoSAError(f' n_obs ({n_obs}) must be an integer greater than 0')
        for name in ['target_res_obs', 'target_res_mod', 'wav_cont', 'res_cont']:
            value = getattr(self, name)
            if len(value) > 1:
                if len(value) != n_obs:
                    raise ForMoSAError(f' Number of observations ({n_obs}) and length of {name} ({len(value)}) are inconsistent')
            elif len(value) == 1 and n_obs > 1:
                setattr(self, name, n_obs * value)

    def _compute_obs_target_resolution(self, observations: ObservationSet, grid: ModelGrid) -> list[np.ndarray]:
        '''
        Compute the target spectral resolution of the observation with respect to the model grid.

        Parameters
        ----------
        observations : ObservationSet
            Instance of class ObservationSet
        grid : ModelGrid
            Instance of class ModelGrid

        Returns
        -------
        list[np.ndarray]
            Target resolution for each observation (0 for photometric ones)

        Raises
        ------
        ForMoSAError
            On wrong argument types, inconsistent list lengths, or a
            non-numeric target resolution other than 'obs'.

        Notes
        -----
        Authors: Simon Petrus, Matthieu Ravet and Allan Denis
        '''
        # Initial check
        if not isinstance(observations, ObservationSet):
            raise ForMoSAError(f'Wrong type for observations: {type(observations)}. Expected an instance of ObservationSet')
        if not isinstance(grid, ModelGrid):
            raise ForMoSAError(f'Wrong type for grid: {type(grid)}. Expected an instance of ModelGrid')

        # Check that len(config_adapt.target_res_obs) is consistent with the number of observations
        self._check_with_n_obs(observations.n_observations)

        output_target_res = []
        # The model-resolution interpolator depends only on the grid: build it
        # once, lazily, the first time a spectroscopic observation is met.
        interp_model_to_obs = None
        for target_res, obs in zip(self.target_res_obs, observations.observations):
            if obs.ObsType == ObservationType.PHOTOMETRIC.obstype:
                output_target_res.append(0)
            elif obs.ObsType == ObservationType.SPECTROSCOPIC.obstype:
                if interp_model_to_obs is None:
                    interp_model_to_obs = interp1d(grid.wave, grid.res, fill_value="extrapolate")
                # Model resolution evaluated on the observation wavelength grid
                res_model_obs = interp_model_to_obs(obs.wave)
                if target_res == "obs":
                    target_res = np.minimum(obs.res, res_model_obs)
                else:
                    res_custom = np.full_like(obs.res, self._as_float(target_res, "target_res_obs"))
                    target_res = np.minimum.reduce([obs.res, res_model_obs, res_custom])
                output_target_res.append(target_res)

        return output_target_res

    def _compute_model_target_wavelength_and_resolution(self, observations: ObservationSet, grid: ModelGrid) -> tuple[list[np.ndarray], list[np.ndarray]]:
        '''
        Determine the target wavelength and resolution for the grids.

        Parameters
        ----------
        observations : ObservationSet
            Instance of class ObservationSet
        grid : ModelGrid
            Instance of class ModelGrid

        Returns
        -------
        tuple[list[np.ndarray], list[np.ndarray]]:
            Targets wavelength and resolution lists (one entry per observation)

        Raises
        ------
        ForMoSAError
            On wrong argument types, inconsistent list lengths, or a
            non-numeric target resolution other than 'obs'/'mod'.

        Notes
        -----
        Authors: Allan Denis
        '''
        # Initial checks
        if not isinstance(observations, ObservationSet):
            raise ForMoSAError(f'Wrong type for observations: {type(observations)}. Expected an instance of ObservationSet')
        if not isinstance(grid, ModelGrid):
            raise ForMoSAError(f'Wrong type for grid: {type(grid)}. Expected an instance of ModelGrid')

        # Check that len(config_adapt.target_res_mod) is consistent with the number of observations
        self._check_with_n_obs(observations.n_observations)

        output_target_wave, output_target_res = [], []
        for target_res, obs in zip(self.target_res_mod, observations.observations):
            if target_res == "obs":
                # Keep the observation's own sampling and resolution
                target_wave = obs.wave.copy()
                target_res = obs.res.copy()
            else:
                target_wave = grid.wave.copy()
                if target_res == "mod":
                    target_res = grid.res.copy()
                else:
                    # target_res is a float: constant resolution on the model grid
                    target_res = np.full(target_wave.shape, self._as_float(target_res, "target_res_mod"), dtype=float)
            output_target_wave.append(target_wave)
            output_target_res.append(target_res)

        return output_target_wave, output_target_res

    def _determine_remove_continuum(self, observations: ObservationSet) -> list[bool]:
        '''
        Determine whether the continuum should be removed for each observation.

        Parameters
        ----------
        observations : ObservationSet
            Instance of class ObservationSet

        Returns
        -------
        list[bool]
            Whether the continuum should be removed for each observation

        Notes
        -----
        Authors: Allan Denis
        '''
        # Initial checks
        if not isinstance(observations, ObservationSet):
            raise ForMoSAError(f'Wrong type for observations: {type(observations)}. Expected an instance of ObservationSet')

        # Check that len(config_adapt.res_cont) is consistent with the number of observations
        self._check_with_n_obs(observations.n_observations)

        remove_continuum = []
        for res_cont, obs in zip(self.res_cont, observations.observations):
            # No continuum resolution -> no continuum removal.
            # High-contrast mode -> never remove continuum from models.
            # Otherwise -> remove continuum from models.
            remove_continuum.append(not (res_cont == "NA" or obs.hc_mode))

        return remove_continuum

    def set_params(self, **kwargs) -> None:
        '''
        Update parameters of ConfigAdapt.

        All keys are checked before any attribute is modified, so an unknown
        key leaves the instance untouched.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to override attributes of the config

        Raises
        ------
        ForMoSAError
            If a key is not an attribute of ConfigAdapt, or if the new values
            fail validation.

        Notes
        -----
        Authors: Allan Denis
        '''
        unknown = [key for key in kwargs if not hasattr(self, key)]
        if unknown:
            raise ForMoSAError(f'Unknown ConfigAdapt key: {unknown[0]}')
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Validation of the format
        self.__post_init__()
[docs] @dataclass class ConfigInversion: logL_type: List[str] = field(default_factory=lambda: ["chi2"]) wav_fit: List[str] = field(default_factory=lambda: ["0.9, 5.0"]) ns_algo: str = field(default_factory=lambda: "pymultinest") npoints: int = field(default_factory=lambda: 50) hc_lower_bounds_lsq: list[float | str] = field(default_factory=lambda: ["NA"]) hc_higher_bounds_lsq: list[float | str] = field(default_factory=lambda: ["NA"])
[docs] def __post_init__(self) -> None: ''' Check inversion configuration parameters and normalize types. Notes ----- Authors: Allan Denis ''' # Check npoints if isinstance(self.npoints, str): self.npoints = int(self.npoints) if not isinstance(self.npoints, int) or self.npoints <= 0: raise ForMoSAError('npoints must be a strictly positive integer') # Check ns_algo if not isinstance(self.ns_algo, str): raise ForMoSAError(f'ns_algo must be a string, got {type(self.ns_algo)}') # Normalize fields self.logL_type = um.normalize_list(self.logL_type, "logL_type") self.wav_fit = um.normalize_list(self.wav_fit, "wav_fit", um.to_float_if_possible) self.hc_lower_bounds_lsq = um.normalize_list(self.hc_lower_bounds_lsq, "hc_lower_bounds_lsq", um.to_float_if_possible) self.hc_higher_bounds_lsq = um.normalize_list(self.hc_higher_bounds_lsq, "hc_higher_bounds_lsq", um.to_float_if_possible) # Check lower and higher hc bounds if len(self.hc_lower_bounds_lsq) != len(self.hc_higher_bounds_lsq): raise ForMoSAError('hc_lower_bounds_lsq and hc_higher_bounds_lsq must have same lengths')
# ======================= # Properties # ======================= @property def hc_bounds(self) -> list[tuple[float, float]]: bounds = [] for lower, higher in zip(self.hc_lower_bounds_lsq, self.hc_higher_bounds_lsq): # Replace 'NA' for lower bound if isinstance(lower, str) and lower.upper() == "NA": lower = -float('inf') # Replace 'NA' for higher bound if isinstance(higher, str) and higher.upper() == "NA": higher = float('inf') bounds.append((lower, higher)) return bounds # ======================= # Methods # ======================= def _check_with_n_obs(self, n_obs: int) -> None: ''' Check consistency between lengths of list parameters and n_obs Parameters ---------- n_obs : int Number to be tested against lengths of list parameters Notes ----- Authors: Allan Denis ''' if (not isinstance(n_obs, int)) or (n_obs < 1): raise ForMoSAError(f' n_obs ({n_obs}) must be an integer greater than 0') for name in ['logL_type', 'wav_fit', 'hc_lower_bounds_lsq', 'hc_higher_bounds_lsq']: value = getattr(self, name) if len(value) > 1: if len(value) != n_obs: raise ForMoSAError(f' Number of observations ({n_obs}) and length of {name} ({len(value)}) are inconsistent') elif len(value) == 1 and n_obs > 1: setattr(self, name, n_obs * value)
[docs] def set_params(self, **kwargs) -> None: ''' Update parameters of ConfigInversion Parameters ---------- **kwargs : dict Keyword arguments to override attributes of the config Notes ----- Authors: Allan Denis ''' for key, value in kwargs.items(): if not hasattr(self, key): raise ForMoSAError(f'Unknown ConfigPath key: {key}') setattr(self, key, value) # Validation of the format self.__post_init__()
@dataclass
class ConfigParameters:
    '''
    Parameters section of the configuration.

    Each attribute holds a list of strings describing the prior of one
    parameter (e.g. ``['uniform', low, high]``), or ``['NA']`` when the
    parameter is not fitted. Extra parameters can be attached at runtime with
    ``_add_parameter``.

    Notes
    -----
    Authors: Allan Denis
    '''
    par1: list[str] = field(default_factory=lambda: ["NA"])
    par2: list[str] = field(default_factory=lambda: ["NA"])
    par3: list[str] = field(default_factory=lambda: ["NA"])
    par4: list[str] = field(default_factory=lambda: ["NA"])
    r: list[str] = field(default_factory=lambda: ["NA"])
    d: list[str] = field(default_factory=lambda: ["NA"])
    alpha: list[str] = field(default_factory=lambda: ["NA"])
    bb_T: list[str] = field(default_factory=lambda: ["NA"])
    bb_R: list[str] = field(default_factory=lambda: ["NA"])
    rv: list[str] = field(default_factory=lambda: ["NA"])
    vsini: list[str] = field(default_factory=lambda: ["NA"])
    ld: list[str] = field(default_factory=lambda: ["NA"])

    def __post_init__(self) -> None:
        '''
        Check parameters configuration.

        Non-list values are parsed with ``ast.literal_eval`` (e.g. the string
        "['uniform', '1', '2']"); every value must end up as a list of strings.

        Raises
        ------
        ForMoSAError
            If a value cannot be parsed, or contains non-string items.

        Notes
        -----
        Authors: Allan Denis
        '''
        for name, value in self.__dict__.items():
            if not isinstance(value, list):
                try:
                    value = ast.literal_eval(value)
                except Exception as e:
                    raise ForMoSAError(f"{name} cannot be parsed as list") from e
                # Replacing the value of an existing key does not resize the dict,
                # so this is safe while iterating over __dict__.
                setattr(self, name, value)
            if not all(isinstance(v, str) for v in value):
                wrong_types = {type(v).__name__ for v in value if not isinstance(v, str)}
                raise ForMoSAError(f"{name} must contain only strings. (got {wrong_types})")

    # =======================
    # Representation
    # =======================
    def __repr__(self) -> str:
        # Grid parameters (par*) first, then the other parameters.
        params = ", ".join(f"{key}={value!r}" for key, value in self.to_dict.items() if key.startswith('par')) + ", " + ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items() if not key.startswith('par'))
        return f"ConfigParameters({params})"

    # =======================
    # Properties
    # =======================
    @property
    def to_dict(self) -> dict:
        """Dictionary representation of ConfigParameters (the live instance ``__dict__``)."""
        return self.__dict__

    # =======================
    # Methods
    # =======================
    def _add_parameter(self, param: str, value: list[str]) -> None:
        '''
        Add a parameter to the instance.

        Parameters
        ----------
        param : str
            Name of the parameter
        value : list[str]
            Value of the parameter

        Notes
        -----
        Authors: Allan Denis
        '''
        setattr(self, str(param), value)
        self.__post_init__()

    def _parse_param_name(self, name: str) -> tuple[str, ParameterKind, str, list[int] | None]:
        '''
        Parse parameter name from the current instance.

        Names without '_' (and 'bb_T' / 'bb_R') are global parameters; names
        like 'rv_1' or 'rv_1_2' are local to the listed observation indices.

        Parameters
        ----------
        name : str
            Name of the parameter in the config file

        Returns
        -------
        tuple[str, ParameterKind, str, list[int] | None]
            (name, kind, scope, obs_index) of the parameter, where scope is
            "global" or "local" and obs_index is None for global parameters

        Raises
        ------
        ForMoSAError
            If ``name`` is not an attribute of the instance.

        Notes
        -----
        Authors: Allan Denis
        '''
        if name not in self.to_dict.keys():
            raise ForMoSAError(f' Please chose a name amongst {self.to_dict.keys()}')

        if "_" not in name or name in ('bb_T', 'bb_R'):
            if name.startswith('par'):
                kind = ParameterKind.GRID
            else:
                kind = ParameterKind[name.upper()]
            return name, kind, "global", None

        base, obs = name.split("_", 1)
        obs_index = [int(i) for i in obs.split("_")]
        kind = ParameterKind[base.upper()]
        return name, kind, "local", obs_index

    def _parse_param_value(self, value: list[str] | str) -> Prior.Prior | str:
        '''
        Parse parameter value from the current instance.

        Parameters
        ----------
        value : list[str]
            Value of the parameter in the config file

        Returns
        -------
        Prior.Prior | str
            Instance of Prior.Prior or 'NA'

        Raises
        ------
        ForMoSAError
            If ``value`` is not one of the instance's values.

        Notes
        -----
        Authors: Allan Denis
        '''
        if value not in self.to_dict.values():
            raise ForMoSAError(f' Please chose a value amongst {self.to_dict.values()}')

        if value[0] == 'NA' or value == 'NA':
            return 'NA'

        prior_type = PriorType[value[0].strip().upper()]
        if prior_type == PriorType.CONSTANT:
            prior = Prior.ConstantPrior(value[1].strip())
        elif prior_type == PriorType.UNIFORM:
            prior = Prior.UniformPrior(value[1].strip(), value[2].strip())
        elif prior_type == PriorType.LOG_UNIFORM:
            prior = Prior.LogUniformPrior(value[1].strip(), value[2].strip())
        else:
            # Any other prior type is treated as Gaussian.
            prior = Prior.GaussianPrior(value[1].strip(), value[2].strip())
        return prior

    def _parse_param(self, name: str, value: str, **kwargs) -> Parameter | None:
        '''
        Parse parameter name and value from the current instance.

        Parameters
        ----------
        name : str
            Name of the parameter in the config file
        value : str
            Value of the parameter in the config file
        **kwargs :
            Additional arguments (logger, log_level, ...)

        Returns
        -------
        Parameter | None
            Instance of class Parameter or None if parameter has no prior

        Raises
        ------
        ForMoSAError
            For a vsini parameter whose last entry is not a known vsini function.

        Notes
        -----
        Authors: Allan Denis
        '''
        name, kind, scope, obs_index = self._parse_param_name(name)
        prior = self._parse_param_value(value)

        if prior != 'NA':
            vsini_function = None
            if kind == ParameterKind.VSINI:
                # For vsini, the last entry of the value names the broadening function.
                vsini_functions = [str(func.function) for func in VsiniFunction]
                vsini_function = value[-1].strip()
                if vsini_function not in vsini_functions:
                    raise ForMoSAError(f'Wrong value for vsini_function: {vsini_function}. You must provide a valid vsini_function amongst {vsini_functions}')
                # NOTE(review): membership is checked against each member's
                # `.function` string but the lookup is by member name — confirm
                # these coincide for every VsiniFunction member.
                vsini_function = VsiniFunction[vsini_function]
            return Parameter(name, prior, kind, scope=scope, obs_index=obs_index, vsini_function=vsini_function, **kwargs)

        # Prior is NA
        return None

    def set_params(self, **kwargs) -> None:
        '''
        Update parameters of ConfigParameters.

        All keys are checked before any attribute is modified, so an unknown
        key leaves the instance untouched.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to override attributes of the config

        Raises
        ------
        ForMoSAError
            If a key is not an attribute of ConfigParameters, or if the new
            values fail validation.

        Notes
        -----
        Authors: Allan Denis
        '''
        unknown = [key for key in kwargs if not hasattr(self, key)]
        if unknown:
            raise ForMoSAError(f'Unknown ConfigParameters key: {unknown[0]}')
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Validation of the format
        self.__post_init__()
@dataclass
class ConfigNestle:
    '''
    Nestle sampler section of the configuration.

    String values (as read from config files) are converted to int/float, and
    the string 'None' is converted to ``None``.

    Notes
    -----
    Authors: Allan Denis
    '''
    method: str = field(default_factory=lambda: "single")
    update_interval: float = field(default_factory=lambda: None)
    npdim: int = field(default_factory=lambda: None)
    maxiter: int = field(default_factory=lambda: None)
    maxcall: int = field(default_factory=lambda: None)
    dlogz: float = field(default_factory=lambda: None)
    decline_factor: float = field(default_factory=lambda: None)
    rstate: Any = field(default_factory=lambda: None)  # seed or RNG, can be int, list, or numpy.random.Generator

    def __post_init__(self) -> None:
        '''
        Check Nestle configuration parameters and normalize types.

        Raises
        ------
        ForMoSAError
            If ``method`` is not a string, or a numeric field holds a string
            that cannot be converted.

        Notes
        -----
        Authors: Allan Denis
        '''
        # Check method
        if not isinstance(self.method, str):
            raise ForMoSAError(f" method must be a string, got {type(self.method)}")

        # Convert string numbers to float/int where appropriate
        int_fields = ["update_interval", "npdim", "maxiter", "maxcall"]
        float_fields = ["dlogz", "decline_factor"]

        for name in int_fields:
            value = getattr(self, name)
            if not isinstance(value, int) and isinstance(value, str):
                try:
                    setattr(self, name, int(value))
                except ValueError:
                    if value == 'None':
                        setattr(self, name, None)
                    else:
                        raise ForMoSAError(f" {name} must be int or str convertible to int, got '{value}'")

        for name in float_fields:
            value = getattr(self, name)
            if not isinstance(value, float) and isinstance(value, str):
                try:
                    setattr(self, name, float(value))
                except ValueError:
                    if value == 'None':
                        setattr(self, name, None)
                    else:
                        raise ForMoSAError(f" {name} must be float or str convertible to float, got '{value}'")

        # rstate may be None, a seed, or an RNG object and is kept as-is; only
        # the 'None' placeholder string is converted. Comparing only strings
        # avoids invoking __eq__ on arbitrary RNG/array objects.
        if isinstance(self.rstate, str) and self.rstate == 'None':
            self.rstate = None

    # =======================
    # Properties
    # =======================
    @property
    def to_dict(self) -> dict:
        """Dictionary representation of ConfigNestle (the live instance ``__dict__``)."""
        return self.__dict__

    # =======================
    # Methods
    # =======================
    @classmethod
    def from_dict(cls, data: dict) -> 'ConfigNestle':
        '''
        Build an instance of ConfigNestle from dictionary of Nestle parameters.

        Parameters
        ----------
        data : dict
            Dictionary whose 'nestle' entry holds the Nestle parameters

        Returns
        -------
        'ConfigNestle'
            An instance of class ConfigNestle

        Notes
        -----
        Authors: Allan Denis
        '''
        if not isinstance(data, dict):
            raise ForMoSAError(f'Wrong type for data: {type(data)}. Expected a dictionary')
        return cls(**data['nestle'])

    def set_params(self, **kwargs) -> None:
        '''
        Update parameters of ConfigNestle.

        All keys are checked before any attribute is modified, so an unknown
        key leaves the instance untouched.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to override attributes of the config

        Raises
        ------
        ForMoSAError
            If a key is not an attribute of ConfigNestle, or if the new values
            fail validation.

        Notes
        -----
        Authors: Allan Denis
        '''
        unknown = [key for key in kwargs if not hasattr(self, key)]
        if unknown:
            raise ForMoSAError(f'Unknown ConfigNestle key: {unknown[0]}')
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Validation of the format
        self.__post_init__()
@dataclass
class ConfigPyMultiNest:
    '''
    PyMultiNest sampler section of the configuration.

    String values (as read from config files) are converted to bool/int/float,
    and the string 'None' is converted to ``None`` for numeric fields.

    Notes
    -----
    Authors: Allan Denis
    '''
    importance_nested_sampling: bool = field(default_factory=lambda: True)
    multimodal: bool = field(default_factory=lambda: True)
    const_efficiency_mode: bool = field(default_factory=lambda: False)
    evidence_tolerance: float = field(default_factory=lambda: 0.5)
    sampling_efficiency: float = field(default_factory=lambda: 0.8)
    n_iter_before_update: int = field(default_factory=lambda: 100)
    null_log_evidence: float = field(default_factory=lambda: -1e90)
    max_modes: int = field(default_factory=lambda: 100)
    mode_tolerance: float = field(default_factory=lambda: -1e90)
    seed: int = field(default_factory=lambda: -1)
    verbose: bool = field(default_factory=lambda: True)
    resume: bool = field(default_factory=lambda: False)
    context: int = field(default_factory=lambda: 0)
    log_zero: float = field(default_factory=lambda: -1e100)
    max_iter: int = field(default_factory=lambda: 0)
    init_MPI: bool = field(default_factory=lambda: False)
    wrapped_params: int = field(default_factory=lambda: None)
    dump_callback: int = field(default_factory=lambda: None)
    use_MPI: bool = field(default_factory=lambda: True)

    def __post_init__(self) -> None:
        '''
        Check PyMultiNest configuration parameters and normalize types.

        Raises
        ------
        ForMoSAError
            If a field holds a value of the wrong type, or a string that cannot
            be converted.

        Notes
        -----
        Authors: Allan Denis
        '''
        # Bool fields
        bool_fields = (
            "importance_nested_sampling", "multimodal", "const_efficiency_mode",
            "verbose", "resume", "init_MPI", "use_MPI",
        )
        for name in bool_fields:
            value = getattr(self, name)
            if isinstance(value, str):
                val_lower = value.lower()
                if val_lower == "true":
                    setattr(self, name, True)
                elif val_lower == "false":
                    setattr(self, name, False)
                else:
                    raise ForMoSAError(f" {name} must be a boolean or 'true'/'false' string, got '{value}'")
            elif not isinstance(value, bool):
                raise ForMoSAError(f" {name} must be a boolean, got {type(value)}")

        # Float fields
        float_fields = ("evidence_tolerance", "sampling_efficiency", "null_log_evidence", "mode_tolerance", "log_zero")
        for name in float_fields:
            value = getattr(self, name)
            if isinstance(value, str):
                try:
                    setattr(self, name, float(value))
                except ValueError:
                    if value == 'None':
                        setattr(self, name, None)
                    else:
                        raise ForMoSAError(f" {name} must be float or string convertible to float, got '{value}'")
            elif not isinstance(value, (float, int)) and value is not None:
                raise ForMoSAError(f" {name} must be float, got {type(value)}")

        # Integer fields
        int_fields = ("n_iter_before_update", "max_modes", "seed", "context", "max_iter", "wrapped_params")
        for name in int_fields:
            value = getattr(self, name)
            if isinstance(value, str):
                try:
                    setattr(self, name, int(value))
                except ValueError:
                    if value == 'None':
                        setattr(self, name, None)
                    else:
                        raise ForMoSAError(f" {name} must be int or string convertible to int, got '{value}'")
            elif (not (isinstance(value, int)) and (value is not None)):
                raise ForMoSAError(f" {name} must be int, got {type(value)}")

        # Optional callback is not serializable from ini files.
        # Convert placeholder strings to None and reject non-callable values.
        if isinstance(self.dump_callback, str):
            if self.dump_callback.strip().lower() in ('none', 'na', ''):
                self.dump_callback = None
            else:
                raise ForMoSAError(
                    f" dump_callback must be None when read from config files, got '{self.dump_callback}'"
                )
        elif self.dump_callback is not None and not callable(self.dump_callback):
            raise ForMoSAError(f" dump_callback must be callable or None, got {type(self.dump_callback)}")

    # =======================
    # Properties
    # =======================
    @property
    def to_dict(self) -> dict:
        """Dictionary representation of ConfigPyMultiNest (the live instance ``__dict__``)."""
        return self.__dict__

    # =======================
    # Methods
    # =======================
    @classmethod
    def from_dict(cls, data: dict) -> 'ConfigPyMultiNest':
        '''
        Build an instance of ConfigPyMultiNest from dictionary of PyMultiNest parameters.

        Parameters
        ----------
        data : dict
            Dictionary whose 'pymultinest' entry holds the PyMultiNest parameters

        Returns
        -------
        'ConfigPyMultiNest'
            An instance of class ConfigPyMultiNest

        Notes
        -----
        Authors: Allan Denis
        '''
        if not isinstance(data, dict):
            raise ForMoSAError(f'Wrong type for data: {type(data)}. Expected a dictionary')
        return cls(**data['pymultinest'])

    def set_params(self, **kwargs) -> None:
        '''
        Update parameters of ConfigPyMultiNest.

        All keys are checked before any attribute is modified, so an unknown
        key leaves the instance untouched.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to override attributes of the config

        Raises
        ------
        ForMoSAError
            If a key is not an attribute of ConfigPyMultiNest, or if the new
            values fail validation.

        Notes
        -----
        Authors: Allan Denis
        '''
        unknown = [key for key in kwargs if not hasattr(self, key)]
        if unknown:
            raise ForMoSAError(f'Unknown ConfigPyMultiNest key: {unknown[0]}')
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Validation of the format
        self.__post_init__()
@dataclass
class ConfigUltraNest:
    '''
    UltraNest sampler section of the configuration.

    Fields are split between the arguments of the ReactiveNestedSampler
    constructor (``ReactiveNSParams``) and those of its run call
    (``runNSParams``). String values (as read from config files) are converted
    to bool/int/float, and the string 'None' is converted to ``None``.

    Notes
    -----
    Authors: Allan Denis
    '''
    # Arguments for ReactiveNestedSampler
    wrapped_params: bool = field(default_factory=lambda: None)
    vectorized: bool = field(default_factory=lambda: False)
    resume: bool = field(default_factory=lambda: True)
    run_num: int = field(default_factory=lambda: None)
    num_bootstraps: int = field(default_factory=lambda: 30)
    storage_backend: str = field(default_factory=lambda: "hdf5")
    warmstart_max_tau: int = field(default_factory=lambda: -1)
    # Arguments for run
    dlogz: float = field(default_factory=lambda: 0.5)
    max_iters: int = field(default_factory=lambda: None)
    max_ncalls: int = field(default_factory=lambda: None)
    min_ess: int = field(default_factory=lambda: 400)
    frac_remain: float = field(default_factory=lambda: 0.01)
    cluster_num_live_points: int = field(default_factory=lambda: 40)
    Lepsilon: float = field(default_factory=lambda: 0.001)

    def __post_init__(self) -> None:
        '''
        Check UltraNest configuration parameters and normalize types.

        Raises
        ------
        ForMoSAError
            If a field holds a value of the wrong type, or a string that cannot
            be converted.

        Notes
        -----
        Authors: Allan Denis
        '''
        # Float fields (plain ints are accepted too)
        for name in ("dlogz", "frac_remain", "Lepsilon"):
            value = getattr(self, name)
            if isinstance(value, str):
                if value == 'None':
                    setattr(self, name, None)
                else:
                    try:
                        setattr(self, name, float(value))
                    except ValueError as exc:
                        raise ForMoSAError(f" {name} must be float or string convertible to float, got '{value}'") from exc
            elif not isinstance(value, (int, float)) and value is not None:
                raise ForMoSAError(f" {name} must be float, got {type(value)}")

        # Integer fields
        int_fields = (
            "run_num",
            "num_bootstraps",
            "warmstart_max_tau",
            "max_iters",
            "max_ncalls",
            "min_ess",
            "cluster_num_live_points",
        )
        for name in int_fields:
            value = getattr(self, name)
            if isinstance(value, str):
                if value == 'None':
                    setattr(self, name, None)
                else:
                    try:
                        setattr(self, name, int(value))
                    except ValueError as exc:
                        raise ForMoSAError(f"{name} must be int or string convertible to int, got '{value}'") from exc
            elif not isinstance(value, int) and value is not None:
                raise ForMoSAError(f" {name} must be int, got {type(value)}")

        # Boolean fields
        for name in ("wrapped_params", "vectorized"):
            value = getattr(self, name)
            if isinstance(value, str):
                if value == 'None':
                    setattr(self, name, None)
                else:
                    val_lower = value.lower()
                    if val_lower == "true":
                        setattr(self, name, True)
                    elif val_lower == "false":
                        setattr(self, name, False)
                    else:
                        raise ForMoSAError(f" {name} must be boolean or 'true'/'false' string, got '{value}'")
            elif not isinstance(value, bool) and value is not None:
                raise ForMoSAError(f" {name} must be boolean, got {type(value)}")

        # resume: 'true'/'false' strings (case-insensitive) become booleans;
        # any other string is passed through unchanged.
        if isinstance(self.resume, str) and self.resume.strip().lower() in ('true', 'false'):
            self.resume = self.resume.strip().lower() == 'true'

    # =======================
    # Properties
    # =======================
    @property
    def ReactiveNSParams(self) -> dict:
        '''Keyword arguments for the ReactiveNestedSampler constructor.'''
        return {
            'wrapped_params': self.wrapped_params,
            'vectorized': self.vectorized,
            'resume': self.resume,
            'run_num': self.run_num,
            'num_bootstraps': self.num_bootstraps,
            'storage_backend': self.storage_backend,
            'warmstart_max_tau': self.warmstart_max_tau
        }

    @property
    def runNSParams(self) -> dict:
        '''Keyword arguments for the sampler run call.'''
        return {
            'dlogz': self.dlogz,
            'max_iters': self.max_iters,
            'max_ncalls': self.max_ncalls,
            'min_ess': self.min_ess,
            'frac_remain': self.frac_remain,
            'cluster_num_live_points': self.cluster_num_live_points,
            'Lepsilon': self.Lepsilon
        }

    @property
    def to_dict(self) -> dict:
        """Dictionary representation of ConfigUltranest."""
        return {'ReactiveNS': self.ReactiveNSParams, 'runNS': self.runNSParams}

    # =======================
    # Methods
    # =======================
    @classmethod
    def from_dict(cls, data: dict) -> 'ConfigUltraNest':
        '''
        Build an instance of ConfigUltraNest from dictionary of UltraNest parameters.

        Parameters
        ----------
        data : dict
            Dictionary whose 'ultranest' entry has 'ReactiveNS' and 'runNS'
            sub-dictionaries (the layout produced by ``to_dict``)

        Returns
        -------
        'ConfigUltraNest'
            An instance of class ConfigUltraNest

        Notes
        -----
        Authors: Allan Denis
        '''
        if not isinstance(data, dict):
            raise ForMoSAError(f'Wrong type for data: {type(data)}. Expected a dictionary')
        return cls(**data['ultranest']['ReactiveNS'], **data['ultranest']['runNS'])

    def set_params(self, **kwargs) -> None:
        '''
        Update parameters of ConfigUltraNest.

        All keys are checked before any attribute is modified, so an unknown
        key leaves the instance untouched.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to override attributes of the config

        Raises
        ------
        ForMoSAError
            If a key is not an attribute of ConfigUltraNest, or if the new
            values fail validation.

        Notes
        -----
        Authors: Allan Denis
        '''
        unknown = [key for key in kwargs if not hasattr(self, key)]
        if unknown:
            raise ForMoSAError(f'Unknown ConfigUltraNest key: {unknown[0]}')
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Validation of the format
        self.__post_init__()
@dataclass
class Config_NS:
    '''
    Group of the nested-sampling configurations (Nestle, PyMultiNest, UltraNest).

    Notes
    -----
    Authors: Allan Denis
    '''
    nestle: ConfigNestle = field(default_factory=lambda: ConfigNestle())
    pymultinest: ConfigPyMultiNest = field(default_factory=lambda: ConfigPyMultiNest())
    ultranest: ConfigUltraNest = field(default_factory=lambda: ConfigUltraNest())

    def __post_init__(self) -> None:
        '''
        Check Config_NS configuration parameters.

        Raises
        ------
        ForMoSAError
            If a sub-configuration has the wrong type.

        Notes
        -----
        Authors: Allan Denis
        '''
        # Check nestle type
        if not isinstance(self.nestle, ConfigNestle):
            raise ForMoSAError(f" nestle must be a ConfigNestle, got {type(self.nestle)}")
        # Check pymultinest type
        if not isinstance(self.pymultinest, ConfigPyMultiNest):
            raise ForMoSAError(f" pymultinest must be a ConfigPyMultinest, got {type(self.pymultinest)}")
        # Check ultranest type
        if not isinstance(self.ultranest, ConfigUltraNest):
            raise ForMoSAError(f" ultranest must be a ConfigUltraNest, got {type(self.ultranest)}")

    # =======================
    # Properties
    # =======================
    @property
    def to_dict(self) -> dict:
        '''Dictionary representation of Config_NS, keyed by sampler name.'''
        return {
            'nestle': self.nestle.to_dict,
            'pymultinest': self.pymultinest.to_dict,
            'ultranest': self.ultranest.to_dict
        }

    # =======================
    # Methods
    # =======================
    @classmethod
    def from_dict(cls, data: dict) -> "Config_NS":
        '''
        Build an instance of Config_NS from a dictionary of data.

        Parameters
        ----------
        data : dict
            Dictionary with 'nestle', 'pymultinest' and 'ultranest' entries

        Returns
        -------
        "Config_NS"
            An instance of class Config_NS

        Notes
        -----
        Authors: Allan Denis
        '''
        if not isinstance(data, dict):
            raise ForMoSAError(f'Wrong type for data: {type(data)}. Expected a dictionary')
        return cls(
            nestle=ConfigNestle.from_dict(data),
            pymultinest=ConfigPyMultiNest.from_dict(data),
            ultranest=ConfigUltraNest.from_dict(data)
        )

    def set_params(self, **kwargs) -> None:
        '''
        Update parameters of Config_NS.

        All keys are checked before any attribute is modified, so an unknown
        key leaves the instance untouched.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to override attributes of the config

        Raises
        ------
        ForMoSAError
            If a key is not an attribute of Config_NS, or if the new values
            fail validation.

        Notes
        -----
        Authors: Allan Denis
        '''
        unknown = [key for key in kwargs if not hasattr(self, key)]
        if unknown:
            raise ForMoSAError(f'Unknown Config_NS key: {unknown[0]}')
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Validation of the format
        self.__post_init__()
# ---------------------------- # Config file generator # ----------------------------
[docs] class ConfigGenerator: ''' Config file generator. Parameters ---------- sections : dict Dictionary containing sections of the config file logger : logging.Logger Logger log_level : str Level of the Logger Notes ----- Authors: Mathieu Ravet and Allan Denis ''' def __init__(self, sections: dict = None, logger: logging.Logger | None = None, log_level: str = 'INFO') -> None: self.logger = logger if logger is not None else setup_logging(level=log_level, name='ConfigGenerator') if sections is not None: # If we provide the sections self.config = sections else: # Otherwise, we generate default sections self.config = { "config_path": ConfigPath('unknown', 'unknown', 'unknown', 'unknown'), "config_adapt": ConfigAdapt(), "config_inversion": ConfigInversion(), "config_parameters": ConfigParameters(), "config_nestle": ConfigNestle(), "config_pymultinest": ConfigPyMultiNest(), "config_ultranest": ConfigUltraNest(), } self.comments = self._init_comments() def __repr__(self): return "ConfigGenerator()" def _init_comments(self): c = {} # ---------------- config_path ---------------- c["config_path"] = { "observation_path": [" # Path to the observed spectrum file"], "adapt_store_path": [" # Path to store your interpolated grid"], "result_path": [" # Path to store your results"], "model_path": [" # Path to the model"] } # ---------------- config_adapt ---------------- c["config_adapt"] = { "method": [ " # Adaptation method. 
/!\\ For safety reasons, this will also be the interpolation method", " # Format : 'linear' or 'nearest' or 'zero' or 'slinear' or 'quadratic' or 'cubic' or 'quintic' or 'pchip' or 'barycentric' or 'krogh' or 'akima' or 'makima'", " # MOSAIC : No" ], "emulator": [ " # If you want to use an emulator to fit your grid (smooth out the grid).", " # Format : 'NA' or 'PCA, ncomp' or 'NMF, ncomp'", " # MOSAIC : No" ], "target_res_obs": [ " # Target resolution to reach for the observation(s).", " # Format : float or 'obs' (if you want to keep the original obs resolution)", " # MOSAIC : Yes" ], "target_res_mod": [ " # Target resolution to reach for the model.", " # Format : float or 'obs' or 'mod' (if you want to keep the model's resolution during inversion)", " # MOSAIC : Yes" ], "wav_cont": [ " # Wavelength range used to estimate the continuum.", " # Format : 'NA' or 'window1_min / window1_max, window2_min / ... / windowN_max'", " # MOSAIC : Yes" ], "res_cont": [ " # Resolution used to estimate the continuum.", " # Format : 'NA' or float", " # MOSAIC : Yes" ], "backend": [ " # Joblib parallel backend used during grid adaptation.", " # Built-in options : 'loky' (default, true multiprocessing), 'multiprocessing', 'threading', 'sequential' (no parallelism)", " # Third-party options (require extra install) : 'dask', 'ray'", " # MOSAIC : No" ], "n_jobs": [ " # Number of parallel jobs used during grid adaptation.", " # Format : -1 (use all available CPUs) or a positive integer (e.g. 
4)", " # MOSAIC : No" ] } # ---------------- config_inversion ---------------- c["config_inversion"] = { "logL_type": [ " # Method to calculate the loglikelihood function used in the nested sampling procedure.", " # Format : 'chi2' or 'chi2_covariance' or 'chi2_noisescaling' or 'chi2_noisescaling_covariance' or 'CCF_Brogi' or 'CCF_Zucker' or 'CCF_custom'", " # MOSAIC : Yes" ], "wav_fit": [ " # Wavelength range(s) used during the nested sampling procedure.", " # Format : 'window1_min / window1_max, window2_min / ... / windowN_max'", " # MOSAIC : Yes" ], "ns_algo": [ " # Nested sampling algorithm used.", " # Format : 'nestle' or 'pymultinest' or 'ultranest'", " # MOSAIC : No" ], "npoints": [ " # Number of living points during the nested sampling procedure.", " # Format : int", " # MOSAIC : No" ], "hc_lower_bounds_lsq": [ " # Least-square bounds.", " # Format : 'NA' or 'lower, upper'", " # MOSAIC : Yes" ], "hc_higher_bounds_lsq": [ " # Least-square bounds.", " # Format : 'NA' or 'lower, upper'", " # MOSAIC : Yes" ] } # ---------------- config_parameters ---------------- c["config_parameters"] = { "par1": [ " # Definition of the prior function of each parameter explored by the grid. Please refer to the documentation to check", " # the parameter space explore by each grid. Check prior functions for more infos", " # Format : 'function', function_param1, function_param2", " # MOSAIC : No" ], "par2": [], "par3": [], "par4": [], "r": [ " # Definition of the prior function of each extra-grid parameter. Check prior functions for more infos", " # Format : 'function', function_param1, function_param2", " # MOSAIC : Yes and No, check the doc !" ], "d": [] } # ---------------- config_nestle ---------------- c["config_nestle"] = { "method": [ " # Nestle configuration parameters. 
For more details, please see: http://kylebarbary.com/nestle/index.html", " # Format : _", " # MOSAIC : No" ], "update_interval": [], "npdim": [], "maxiter": [], "maxcall": [], "dlogz": [], "decline_factor": [], "rstate": [] } # ---------------- config_pymultinest ---------------- c["config_pymultinest"] = { "n_clustering_params": [ " # Pymultinest configuration parameters. For more details, please see: https://github.com/JohannesBuchner/PyMultiNest/blob/master/pymultinest/run.py", " # Format : _", " # MOSAIC : No" ], "wrapped_params": [], "importance_nested_sampling": [], "multimodal": [], "const_efficiency_mode": [], "evidence_tolerance": [], "sampling_efficiency": [], "n_iter_before_update": [], "null_log_evidence": [], "max_modes": [], "mode_tolerance": [], "seed": [], "verbose": [], "resume": [], "context": [], "log_zero": [], "max_iter": [], "init_MPI": [], "dump_callback": [], "use_MPI": [] } # ---------------- config_ultranest ---------------- c["config_ultranest"] = { "resume": [ " # Ultranest configuration parameters. For more details, please see: https://johannesbuchner.github.io/UltraNest/readme.html", " # Format : _", " # MOSAIC : No" ] } # Other keys are added without comments for key in asdict(ConfigUltraNest()).keys(): if key not in c["config_ultranest"]: c["config_ultranest"][key] = [] return c
[docs] def save(self, path: str | os.PathLike, name: str = 'new_config.ini'): ''' Save the ConfigGenerator to a given path. Parameters ---------- path : str | os.PathLike Path where to save the config file name : str Name of the config file Notes ----- Authors: Allan Denis ''' self.logger.info(f' Save config to path {path}') path = Path(path) / name config = ConfigObj(indent_type=' ', list_values=True) for sec_name, sec_obj in self.config.items(): config[sec_name] = {} sec_dict = sec_obj.__dict__ for key, val in sec_dict.items(): config[sec_name][key] = val if sec_name in self.comments and key in self.comments[sec_name]: config[sec_name].comments[key] = self.comments[sec_name][key] config.comments[sec_name] = [""] + config.comments.get(sec_name, []) config.filename = path config.write()
[docs] class ConfigLoader: ''' Class handling the loading of a config file. Parameters ---------- path : str | os.PathLike Path to the config file logger : logging.Logger Logger log_level : str Level of the Logger Notes ----- Authors: Allan Denis ''' def __init__(self, path: str | os.PathLike, logger: logging.Logger | None = None, log_level: str = 'INFO') -> None: self.path = path self.config_ini = ConfigObj(self.path, list_values=True, encoding='utf-8', file_error=False) self.logger = logger if logger is not None else setup_logging(level=log_level, name='ConfigLoader') self.defaults = ConfigGenerator().config self.config = {}
[docs] def load(self): ''' Load all the sections of .ini file in dataclasses. Notes ----- Authors: Allan Denis ''' self.logger.debug(f'Load config file {self.path}') # Default config file self._fill_defaults() # mapping section name -> dataclass mapping = { "config_path": ConfigPath, "config_adapt": ConfigAdapt, "config_inversion": ConfigInversion, "config_parameters": ConfigParameters, "config_nestle": ConfigNestle, "config_pymultinest": ConfigPyMultiNest, "config_ultranest": ConfigUltraNest, } for section, cls in mapping.items(): if section in self.config_ini: data = {} for key, val in self.config_ini[section].items(): if key not in self.defaults[section].__dict__.keys(): continue data[key] = val if section == "config_parameters": params = ConfigParameters() for key, val in self.config_ini[section].items(): params._add_parameter(key, val) self.config[section] = params continue self.config[section] = cls(**data) else: self.config[section] = cls() # Default values if not present self.logger.info(' Config file loaded') return self.config
def _fill_defaults(self): ''' Add missing sections / keys using defaults without overwriting existing values. Notes ----- Authors: Allan Denis ''' for section, default_obj in self.defaults.items(): if section not in self.config_ini: self.config_ini[section] = {} default_dict = asdict(default_obj) for key, value in default_dict.items(): if key not in self.config_ini[section]: self.config_ini[section][key] = value