# Source code for ForMoSA.observation.observation_set

import os
import json
import logging
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.figure import Figure
import matplotlib.gridspec as gridspec
from matplotlib.axes._axes import Axes
from matplotlib import colors as mcolors
from matplotlib.ticker import AutoMinorLocator

from ForMoSA.core.config import MAIN_PLOT
from ForMoSA.core.errors import ForMoSAError
from ForMoSA.core.enums import ObservationKeys
from ForMoSA.core.loggings import setup_logging
from ForMoSA.observation.observation_base import Observation
from ForMoSA.observation.observation_spectroscopy import SpectralObservation
from ForMoSA.observation.observation_photometry import PhotometryObservation

class ObservationSet(object):
    '''
    Container for a set of Observation objects.

    Parameters
    ----------
    logger : logging.Logger | None
        Logger to use. When None, a new logger named "ObservationSet" is
        created with ``setup_logging``.
    log_level : str
        Level of the logging (only used when ``logger`` is None)

    Notes
    -----
    Authors: Allan Denis
    '''

    def __init__(self, logger: logging.Logger | None = None, log_level: str = "INFO") -> None:
        self._logger = logger if logger is not None else setup_logging(level=log_level, name="ObservationSet")
        self._observations: list[Observation] = []

    # ==================================================
    # Representation
    # ==================================================
    def __repr__(self) -> str:
        return f' ObservationSet : {self.n_observations} observations'

    def __format__(self, format_spec: str = "") -> str:
        # format() and f-strings always pass a format spec to __format__;
        # the spec is applied to the textual representation.
        return format(self.__repr__(), format_spec)

    # ==================================================
    # Collection protocol
    # ==================================================
    def __len__(self) -> int:
        return len(self._observations)

    def __iter__(self):
        return iter(self._observations)

    def __getitem__(self, idx: int) -> Observation:
        return self._observations[idx]

    # ==================================================
    # Properties
    # ==================================================
    @property
    def is_empty(self) -> bool:
        """Whether ObservationSet is empty."""
        return not self._observations

    @property
    def logger(self) -> logging.Logger:
        """Logger."""
        return self._logger

    @property
    def observation_names(self) -> list[str]:
        """List of observation names."""
        return [obs.name for obs in self._observations]

    @property
    def observations(self) -> list[Observation]:
        """List of observations."""
        return self._observations

    @property
    def n_observations(self) -> int:
        """Number of observations."""
        return len(self)

    @property
    def has_spectroscopy(self) -> bool:
        """Whether the observation set has spectroscopy."""
        return any(obs.is_spectroscopic for obs in self._observations)

    @property
    def has_photometry(self) -> bool:
        """Whether the observation set has photometry."""
        return any(obs.is_photometric for obs in self._observations)

    @property
    def has_high_contrast(self) -> bool:
        """Whether the observation set has high-contrast observations."""
        return any(obs.hc_mode for obs in self._observations)

    @property
    def spectral_observations(self) -> list[SpectralObservation]:
        """List of spectroscopic observations."""
        return [obs for obs in self._observations if obs.is_spectroscopic]

    @property
    def photometry_observations(self) -> list[PhotometryObservation]:
        """List of photometric observations."""
        return [obs for obs in self._observations if obs.is_photometric]

    @property
    def high_contrast_observations(self) -> list[Observation]:
        """List of high-contrast observations."""
        return [obs for obs in self._observations if obs.hc_mode]

    @property
    def max_resolution(self) -> float | None:
        """Maximum resolution (None if no spectroscopic observation)."""
        specs = self.spectral_observations
        if not specs:
            return None
        return max(obs.max_resolution for obs in specs)

    @property
    def min_resolution(self) -> float | None:
        """Minimum resolution (None if no spectroscopic observation)."""
        specs = self.spectral_observations
        if not specs:
            return None
        return min(obs.min_resolution for obs in specs)

    @property
    def wavelength_range(self) -> tuple[float, float]:
        """Global wavelength range (raises ValueError when the set is empty)."""
        wmins = [obs.wavelength_range[0] for obs in self._observations]
        wmaxs = [obs.wavelength_range[1] for obs in self._observations]
        return min(wmins), max(wmaxs)

    @property
    def to_dict(self) -> dict:
        """Dictionary representation of the set of observations, keyed by observation name."""
        return {obs.name: obs.to_dict for obs in self._observations}

    @property
    def mcolors_normalize(self) -> mcolors.Normalize:
        """Color normalization spanning the global wavelength range (for plotting)."""
        wmin, wmax = self.wavelength_range
        return mcolors.Normalize(vmin=wmin, vmax=wmax)

    # ==================================================
    # Class methods
    # ==================================================
    @classmethod
    def from_npz(cls, path: str | os.PathLike, logger: logging.Logger | None = None, log_level: str = 'INFO') -> "ObservationSet":
        '''
        Create an instance of ObservationSet from a path containing observation .npz files.

        The files are looked up in the ``Observations`` sub-directory of ``path``.
        When an ``observation_order.json`` file is present in that directory, it
        defines which observations are loaded and in which order; otherwise the
        ``.npz`` files are loaded in alphabetical order.

        Parameters
        ----------
        path : str | os.PathLike
            Path containing the ``Observations`` directory
        logger : logging.Logger
            Logger
        log_level : str
            Level of the Logger (only used when ``logger`` is None)

        Returns
        -------
        "ObservationSet"
            An instance of ObservationSet

        Raises
        ------
        ForMoSAError
            If ``path`` has a wrong type, if the ``Observations`` directory does
            not exist, or if it contains no ``.npz`` file

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger if logger is not None else setup_logging(level=log_level, name='ObservationSet')
        logger.debug(f'Generating a set of observations from path {path}')

        # Initial checking
        if not isinstance(path, (str, os.PathLike)):
            raise ForMoSAError(f' Wrong type for path: {type(path)}. Expected a str or os.PathLike', logger)

        obs_set = cls(logger=logger)
        obs_path = Path(path).expanduser() / 'Observations'

        # Initial checks
        if not obs_path.exists():
            raise ForMoSAError(f'{obs_path} does not exist', logger)

        # os.listdir returns entries in arbitrary order: sort for a deterministic fallback order
        obs_files = sorted(obs_file for obs_file in os.listdir(obs_path) if obs_file.endswith('.npz'))
        if not obs_files:
            raise ForMoSAError(f'No .npz observation file found in {obs_path}', logger)

        # Generate ordered observations
        order_file = obs_path / "observation_order.json"
        if order_file.exists():
            with open(order_file, "r") as f:
                ordered_names = json.load(f)
            obs_files = [f"Observation_{name}.npz" for name in ordered_names]
        else:
            logger.warning("No order_file found. The extracted observation order likely won't match the initial observation order")

        for obs_file in obs_files:
            obs = Observation.from_file(obs_path / obs_file, logger=logger)
            obs_set.add_observation(obs)

        logger.info(' Set of Observations generated')
        return obs_set

    @classmethod
    def from_fits(cls, path: list[str | os.PathLike], logger: logging.Logger | None = None, log_level: str = 'INFO') -> "ObservationSet":
        '''
        Create an instance of ObservationSet from a list of .fits files and/or directories.

        Each entry of ``path`` is either a ``.fits`` file, or a directory whose
        ``*.fits`` files are all loaded (in alphabetical order).

        Parameters
        ----------
        path : list[str | os.PathLike]
            Paths to the observations
        logger : logging.Logger
            Logger
        log_level : str
            Level of the Logger (only used when ``logger`` is None)

        Returns
        -------
        "ObservationSet"
            An instance of ObservationSet

        Raises
        ------
        ForMoSAError
            If ``path`` is not a list of str / os.PathLike, if an entry does not
            exist, if a file entry is not a ``.fits`` file, or if no ``.fits``
            file is found at all

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger if logger is not None else setup_logging(level=log_level, name='ObservationSet')
        logger.debug(f'Generating a set of observations from path {path}')

        # Initial checking
        if not isinstance(path, list):
            raise ForMoSAError(f'Wrong type for path: {type(path)}. Expected a list', logger)
        if not all(isinstance(obs_path, (str, os.PathLike)) for obs_path in path):
            raise ForMoSAError('path must be a list of str or os.PathLike', logger)

        obs_set = cls(logger=logger)
        obs_files: list[Path] = []

        # Iterate over provided paths
        for p in path:
            p = Path(p).expanduser()
            if not p.exists():
                raise ForMoSAError(f'{p} does not exist', logger)

            # Case 1: it's a file
            if p.is_file():
                if p.suffix.lower() != '.fits':
                    raise ForMoSAError(f'{p} is not a .fits file', logger)
                obs_files.append(p)

            # Case 2: it's a directory
            elif p.is_dir():
                # glob yields entries in arbitrary order: sort for a deterministic result
                fits_in_dir = sorted(p.glob('*.fits'))
                if not fits_in_dir:
                    logger.warning(f'No .fits files found in {p}')
                obs_files.extend(fits_in_dir)

        if not obs_files:
            raise ForMoSAError('No .fits files found in provided paths', logger)

        # Load observations
        for file_path in obs_files:
            obs = Observation.from_file(file_path, logger=logger)
            obs_set.add_observation(obs)

        logger.info(' Set of Observations generated')
        return obs_set

    @classmethod
    def from_dict(cls, data: dict, logger: logging.Logger | None = None, log_level: str = 'INFO') -> 'ObservationSet':
        '''
        Reconstruct an ObservationSet from its dictionary representation.

        Parameters
        ----------
        data : dict
            Dictionary whose values are the dictionary representations of the
            observations (as produced by ``to_dict``)
        logger : logging.Logger
            Logger
        log_level : str
            Level of the logging (only used when ``logger`` is None)

        Returns
        -------
        'ObservationSet'
            An instance of class ObservationSet

        Raises
        ------
        ForMoSAError
            If ``data`` is not a dictionary

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger if logger is not None else setup_logging(level=log_level, name='ObservationSet')
        if not isinstance(data, dict):
            raise ForMoSAError(f'Wrong type for data: {type(data)}. Expected a dictionary', logger)

        obs_set = cls(logger=logger)
        logger.debug('Build instance of ObservationSet from dictionary')
        for obs_data in data.values():
            obs = Observation.from_dict(data=obs_data, logger=logger)
            obs_set.add_observation(obs)
        return obs_set

    @classmethod
    def from_json(cls, path: str | os.PathLike, logger: logging.Logger | None = None, log_level: str = 'INFO') -> 'ObservationSet':
        '''
        Reconstruct an ObservationSet from a json file.

        The file read is ``<path>/observations.json``, i.e. the file written by
        ``to_json(path)``.

        Parameters
        ----------
        path : str | os.PathLike
            Directory containing the ``observations.json`` file
        logger : logging.Logger
            Logger
        log_level : str
            Level of the logging (only used when ``logger`` is None)

        Returns
        -------
        'ObservationSet'
            An instance of class ObservationSet

        Raises
        ------
        ForMoSAError
            If ``path`` has a wrong type or if the json file does not exist

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger if logger is not None else setup_logging(level=log_level, name='ObservationSet')
        if not isinstance(path, (str, os.PathLike)):
            raise ForMoSAError(f'Wrong type for path: {type(path)}. Expected a string or os.PathLike', logger)

        # Join with the path separator so that the location matches the one used by to_json
        filepath = Path(path).expanduser() / 'observations.json'
        if not filepath.exists():
            # Backward compatibility: the file name used to be appended to the raw
            # string, so a caller may have passed a prefix rather than a directory.
            legacy_filepath = Path(str(path) + 'observations.json')
            if not legacy_filepath.exists():
                raise ForMoSAError(f'{filepath} does not exist', logger)
            filepath = legacy_filepath

        logger.debug(f'Building instance of ObservationSet from json file {filepath}')
        with open(filepath, "r") as f:
            data = json.load(f)
        return cls.from_dict(data, logger=logger)

    @classmethod
    def from_list(cls, obs_list: list[Observation], logger: logging.Logger | None = None, log_level: str = 'INFO') -> 'ObservationSet':
        '''
        Build an ObservationSet from a list of observations.

        Parameters
        ----------
        obs_list : list[Observation]
            List containing observations
        logger : logging.Logger
            Logger
        log_level : str
            Level of the logging (only used when ``logger`` is None)

        Returns
        -------
        'ObservationSet'
            An instance of class ObservationSet

        Raises
        ------
        ForMoSAError
            If ``obs_list`` is not a list of Observation objects

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger if logger is not None else setup_logging(level=log_level, name='ObservationSet')
        if not isinstance(obs_list, list):
            raise ForMoSAError(f'Wrong type for obs_list: {type(obs_list)}. Expected a list', logger)
        if not all(isinstance(obs, Observation) for obs in obs_list):
            raise ForMoSAError('Expected a list of observations for obs_list', logger)

        obs_set = cls(logger=logger)
        logger.debug('Build instance of ObservationSet from list')
        for obs in obs_list:
            obs_set.add_observation(obs)
        return obs_set

    # ==================================================
    # Methods
    # ==================================================
    def add_observation(self, *args, **kwargs):
        '''
        Add an observation to the set based on the type of data provided.

        Parameters
        ----------
        args :
            - If a Observation object is provided, directly add the observation
            - If a file is provided, provide a single argument `path` (str | Path)
            - If a dictionary of data is provided, provide a single argument `data` (dict)
            - If attributes are provided, provide the necessary keyword arguments
              to create the observation (Spectral or Photometric)
        kwargs :
            Additional attributes for the observations if necessary. A single
            keyword argument named `path` or `data` is handled like the
            corresponding positional argument.

        Example:
            - self.add_observation(path="path/to/file.fits")
            - self.add_observation(data={"wavelength": ..., "flux": ...})
            - self.add_observation(name="spectral_obs", wavelength=..., flux=..., ...)

        Raises
        ------
        ForMoSAError
            If the input type is not recognized or if no valid data is provided

        Notes
        -----
        Authors: Allan Denis
        '''
        # A lone `path=` / `data=` keyword is the documented way to pass a file
        # or a dictionary: treat it like the equivalent positional argument.
        if not args and len(kwargs) == 1:
            if 'path' in kwargs:
                args = (kwargs.pop('path'),)
            elif 'data' in kwargs:
                args = (kwargs.pop('data'),)

        if len(args) == 1:
            source = args[0]
            if isinstance(source, (Observation, SpectralObservation, PhotometryObservation)):
                # If the argument is an Observation
                obs = source
            elif isinstance(source, (str, os.PathLike)):
                # If the argument is a path to an observation file
                obs = Observation.from_file(source, logger=self.logger, **kwargs)
            elif isinstance(source, dict):
                # If the argument is a dictionary of data
                obs = Observation.from_dict(source, logger=self.logger, **kwargs)
            else:
                raise ForMoSAError(f"Unrecognized input type {type(source)}", self.logger)
        elif len(kwargs) > 1:
            # If multiple keyword arguments are provided, we assume they are attributes
            obs = Observation.from_attributes(**kwargs)
        else:
            raise ForMoSAError('No valid data provided to add an observation', self.logger)

        self.logger.info(f' Adding {obs.ObsType} Observation with name {obs.name} to the set of observations')
        self._observations.append(obs)

    def save_all(self, path: str | os.PathLike, to_json: bool = False) -> None:
        '''
        Save all observations to disk.

        Everything is written in the ``Observations`` sub-directory of ``path``
        (created if needed). The observation order is also saved in
        ``observation_order.json`` so that it can be restored by ``from_npz``.

        Parameters
        ----------
        path : str | os.PathLike
            Directory where to save the observations
        to_json : bool
            If True, save all observations in a single json file (see
            ``to_json``). Otherwise each observation is saved through its own
            ``save_observation`` method.

        Notes
        -----
        Authors: Allan Denis
        '''
        path = Path(path).expanduser() / 'Observations'
        self.logger.info(f' Saving all the observations {self.observation_names} to path {path}')
        if not path.exists():
            self.logger.warning(f'path {path} does not exist. Creating it')
            path.mkdir(exist_ok=True, parents=True)

        if to_json is True:
            self.to_json(path)
        else:
            for obs in self.observations:
                obs.save_observation(path)

        # Save order
        order_file = path / "observation_order.json"
        with open(order_file, "w") as f:
            json.dump(self.observation_names, f, indent=4)

    def adapt_all(self, target_resolution: list[np.ndarray], wave_cont: list[str] | None = None, res_cont: list[float] | None = None) -> None:
        '''
        Adapt all observations to the target resolution.

        Parameters
        ----------
        target_resolution : list[np.ndarray]
            List of target resolution to reach for the observations (one entry per observation)
        wave_cont : list[str]
            List of wavelengths used for the continuum (one entry per observation)
        res_cont : list[float]
            List of resolutions used for the continuum (one entry per observation)

        Raises
        ------
        ForMoSAError
            If one of the arguments has a wrong type, or a length different
            from the number of observations

        Notes
        -----
        Authors: Simon Petrus, Matthieu Ravet and Allan Denis
        '''
        n_obs = self.n_observations

        # Initial checks
        if not isinstance(target_resolution, list):
            raise ForMoSAError(f' Wrong type for target_resolution: {type(target_resolution)}. Expected a list', self.logger)
        if len(target_resolution) != n_obs:
            raise ForMoSAError(f' Wrong length for target_resolution: {len(target_resolution)}. Expected {n_obs}', self.logger)

        if wave_cont is None:
            wave_cont = [None] * n_obs
        elif not isinstance(wave_cont, list):
            raise ForMoSAError(f' Wrong type for wave_cont: {type(wave_cont)}. Expected a list or None', self.logger)
        elif len(wave_cont) != n_obs:
            # Same check as for res_cont: avoids a late IndexError in the loop below
            raise ForMoSAError(f' Wrong length for wave_cont: {len(wave_cont)}. Expected {n_obs}', self.logger)

        if res_cont is None:
            res_cont = [None] * n_obs
        elif not isinstance(res_cont, list):
            raise ForMoSAError(f' Wrong type for res_cont: {type(res_cont)}. Expected a list or None', self.logger)
        elif len(res_cont) != n_obs:
            raise ForMoSAError(f' Wrong length for res_cont: {len(res_cont)}. Expected {n_obs}', self.logger)

        self.logger.debug(f' Adapting all the observations {self.observation_names}')

        # Adaptation to observations
        for i, obs in enumerate(self._observations):
            self._logger.info(f' Adapting Observation: {obs.name}')
            # The value returned by _adapt_to_resolution replaces the observation at the same index
            self.observations[i] = obs._adapt_to_resolution(target_resolution[i], wave_cont=wave_cont[i], res_cont=res_cont[i])

        self.logger.info(f' Observations {self.observation_names} adapted')

    def to_json(self, path: str | os.PathLike) -> None:
        '''
        Save the set of observations to a given path as a json file.

        The file is written to ``<path>/observations.json``; ``path`` is
        created if it does not exist.

        Parameters
        ----------
        path : str | os.PathLike
            Directory where to save the set of observations

        Raises
        ------
        ForMoSAError
            If ``path`` has a wrong type

        Notes
        -----
        Authors: Allan Denis
        '''
        if not isinstance(path, (str, os.PathLike)):
            raise ForMoSAError(f'Wrong type for path: {type(path)}. Expected a string or os.PathLike', self.logger)

        path = Path(path)
        self.logger.info(f' Saving set of observations to json path {path / "observations.json"}')
        if not path.exists():
            self.logger.warning(f'{path} does not exist. Creating it.')
            path.mkdir(exist_ok=True, parents=True)

        with open(path / 'observations.json', 'w') as f:
            json.dump(self.to_dict, f, indent=4)

    def plot_all(self, fig: Figure | None = None, ax: Axes | None = None, ax_hc: Axes | None = None, ax_filt: Axes | None = None) -> tuple[Figure, Axes, Axes | None]:
        '''
        Plot all the observations and photometric filters.

        Parameters
        ----------
        fig : matplotlib.figure.Figure
            Figure (used to overplot on an existing figure)
        ax : matplotlib.axes._axes.Axes
            Ax (used to overplot non-high-contrast observations)
        ax_hc : matplotlib.axes._axes.Axes
            Ax used to overplot high-contrast observations. When None and ax is
            provided, HC observations fall back to ax (e.g. inside plot_fit).
        ax_filt : matplotlib.axes._axes.Axes
            Ax used to overplot the transmission filter

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure object
        ax : matplotlib.axes._axes.Axes
            Main plotting axis (``ax`` when available, ``ax_hc`` otherwise)
        ax_filt : matplotlib.axes._axes.Axes
            Ax object for photometric filters (may be None)

        Raises
        ------
        ForMoSAError
            If no plotting axis is available once all observations are plotted

        Notes
        -----
        Authors: Allan Denis
        '''
        self.logger.info(f' Plotting all the observations {self.observation_names}')
        main_plot_config = MAIN_PLOT

        # Create figure if not provided
        if fig is None:
            fig = plt.figure(figsize=main_plot_config.figsize)

        # Create axes for observations if not provided
        if ax is None and ax_hc is None:
            # Create a gridspec to have more control over the layout of the axes
            gs = gridspec.GridSpec(9, 10)
            # If we have both high-contrast and non-high-contrast observations,
            # we create two separate axes for them
            if self.has_high_contrast and len(self.high_contrast_observations) != self.n_observations:
                ax = fig.add_subplot(gs[2:6, 0:10])
                ax_hc = fig.add_subplot(gs[6:9, 0:10], sharex=ax)
            # Otherwise a single axis uses the whole space: the high-contrast
            # axis when all observations are high-contrast, the regular one if none is
            else:
                if self.has_high_contrast:
                    ax_hc = fig.add_subplot(gs[2:9, 0:10])
                else:
                    ax = fig.add_subplot(gs[2:9, 0:10])

        # Create photometric filter axis only if not provided
        if self.has_photometry and ax_filt is None:
            gs = gridspec.GridSpec(9, 10)
            ax_filt = fig.add_subplot(gs[0:2, 0:10], sharex=(ax if ax is not None else ax_hc))

        # Plot each observation — legend is suppressed here;
        # plot_all renders a single consolidated legend below.
        for obs in self.observations:
            if not obs.hc_mode:
                fig, ax, ax_filt = obs.plot_data(fig=fig, ax=ax, ax_filt=ax_filt, draw_legend=False)
            else:
                # When ax_hc is not set (e.g. called with a pre-existing ax from
                # plot_fit), fall back to ax so data lands on the main axes.
                _target_hc = ax_hc if ax_hc is not None else ax
                fig, _target_hc, ax_filt = obs.plot_data(fig=fig, ax=_target_hc, ax_filt=ax_filt, draw_legend=False)
                if ax_hc is not None:
                    ax_hc = _target_hc
                else:
                    ax = _target_hc

        # Use whichever axis was effectively used for spectral plotting.
        plot_axis = ax if ax is not None else ax_hc
        if plot_axis is None:
            raise ForMoSAError('No plotting axis available for observations', self.logger)

        if self.has_high_contrast and len(self.high_contrast_observations) == self.n_observations:
            ncol = max(1, int(main_plot_config.legend_hc_ncol))
        else:
            ncol = max(1, int(main_plot_config.legend_ncol))

        handles, labels = plot_axis.get_legend_handles_labels()
        if handles:
            plot_axis.legend(ncol=ncol, frameon=False, loc='upper right', fontsize=main_plot_config.legend_fontsize)

        # Add legend for photometric filters if we have photometry and an axis for the filters
        if ax_filt is not None:
            filt_handles, filt_labels = ax_filt.get_legend_handles_labels()
            if filt_handles:
                ax_filt.legend(ncol=max(1, int(main_plot_config.legend_filt_ncol)), frameon=False)

        # Minor ticks
        if main_plot_config.minor_ticks:
            # Principal axis
            plot_axis.xaxis.set_minor_locator(AutoMinorLocator(main_plot_config.nb_minor_ticks))
            plot_axis.yaxis.set_minor_locator(AutoMinorLocator(main_plot_config.nb_minor_ticks))
            if ax_filt is not None:
                # Filter axis
                ax_filt.xaxis.set_minor_locator(AutoMinorLocator(main_plot_config.nb_minor_ticks))
                ax_filt.yaxis.set_minor_locator(AutoMinorLocator(main_plot_config.nb_minor_ticks))

        # Rescale y axis with a power of 10
        ymin, ymax = plot_axis.get_ylim()
        ymax_abs = max(abs(ymin), abs(ymax))
        if ymax_abs > 0:
            exponent = int(np.floor(np.log10(ymax_abs)))
            plot_axis.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, pos: f"{y/10**exponent:.1f}"))
            plot_axis.set_ylabel(rf'Flux ($10^{{{exponent}}}$ W.m$^{{-2}}$.$\mu$m$^{{-1}}$)')
        else:
            plot_axis.set_ylabel(r'Flux (W.m$^{-2}$.$\mu$m$^{-1}$)')

        return fig, plot_axis, ax_filt

    def _stack(self, ind_obs: list[int] | None = None, print_logger: bool = False) -> dict:
        '''
        Stack observations into single arrays sorted by wavelength.

        Parameters
        ----------
        ind_obs : list[int]
            List of index of observations to stack. If None, stack all observations
        print_logger : bool
            Whether to log an info message before stacking

        Returns
        -------
        dict
            Stacked wavelength, flux, error and resolution arrays, sorted by
            wavelength and keyed by the canonical names of ``ObservationKeys``

        Notes
        -----
        Authors: Allan Denis
        '''
        if print_logger:
            self.logger.info(" Stacking observations")

        if ind_obs is None:
            ind_obs = range(self.n_observations)

        wavelength_all = []
        flux_all = []
        error_all = []
        res_all = []
        for iobs in ind_obs:
            obs_data = self.observations[iobs]
            # atleast_1d so that scalar values can be concatenated with arrays
            wavelength_all.append(np.atleast_1d(obs_data.wave))
            flux_all.append(np.atleast_1d(obs_data.flux))
            error_all.append(np.atleast_1d(obs_data.err))
            res_all.append(np.atleast_1d(obs_data.res))

        # concatenate
        wavelength_all = np.concatenate(wavelength_all)
        flux_all = np.concatenate(flux_all)
        error_all = np.concatenate(error_all)
        res_all = np.concatenate(res_all)

        # sort by wavelength
        idx = np.argsort(wavelength_all)
        stacked = {
            ObservationKeys.WAVELENGTH.canonical: wavelength_all[idx],
            ObservationKeys.FLUX.canonical: flux_all[idx],
            ObservationKeys.ERROR.canonical: error_all[idx],
            ObservationKeys.RESOLUTION.canonical: res_all[idx],
        }
        return stacked