Source code for ForMoSA.observation.observation_loader

import os
import logging
from astropy.io import fits
import numpy as np
from collections.abc import Iterable, Mapping

from ForMoSA.core.errors import ForMoSAError
from ForMoSA.core.loggings import setup_logging
from ForMoSA.core.enums import ObservationKeys, WavelengthUnit
from ForMoSA.observation.observation_base import Observation
from ForMoSA.observation.observation_spectroscopy import SpectralObservation
from ForMoSA.observation.observation_photometry import PhotometryObservation
from ForMoSA.utils import misc

class ObservationLoader:
    '''
    Build Observation objects from FITS files, in-memory mappings or
    keyword attributes.

    The three entry points (``_from_fits``, ``_from_data`` and
    ``_from_attributes``) all convert their input into a mapping and delegate
    to ``_from_mapping``, which decides between a spectroscopic and a
    photometric observation.

    Notes
    -----
    Authors: Allan Denis
    '''

    @staticmethod
    def _attributes_to_dict(**kwargs) -> dict:
        '''
        Convert keyword attributes into a dictionary of numpy arrays.

        Parameters
        ----------
        **kwargs :
            Keyword arguments. Entries whose value is None are dropped.

        Returns
        -------
        dict
            ``{name: np.asarray(value)}`` for every non-None attribute.

        Notes
        -----
        Authors: Allan Denis
        '''
        return {name: np.asarray(value)
                for name, value in kwargs.items()
                if value is not None}

    @staticmethod
    def _normalize_keys(keys: Iterable[str]) -> dict[str, str]:
        '''
        Map the input keys onto canonical ObservationKeys names.

        The comparison is case-insensitive: both the input keys and the
        aliases declared on ``ObservationKeys`` are upper-cased. For each
        ObservationKeys member, the first alias found among the input keys wins.

        Parameters
        ----------
        keys : iterable of str
            Column names (FITS table) or dictionary keys.

        Returns
        -------
        dict
            Mapping {canonical_key: actual_key_in_input}. Canonical keys with
            no matching alias are absent.

        Raises
        ------
        ForMoSAError
            If ``keys`` is not an iterable, or is a single string / bytes
            object (which would otherwise be iterated character by character).

        Examples
        --------
        >>> ObservationLoader._normalize_keys(["wave", "flux", "err", "instrument"])
        {'WAVELENGTH': 'wave', 'FLUX': 'flux', 'ERROR': 'err', 'INSTRUMENT': 'instrument'}

        Notes
        -----
        Authors: Allan Denis
        '''
        if isinstance(keys, (str, bytes)) or not isinstance(keys, Iterable):
            raise ForMoSAError("keys must be an iterable of strings")

        # Upper-cased key -> key exactly as spelled in the input
        upper_keys = {k.upper(): k for k in keys}

        normalized: dict[str, str] = {}
        for obs_key in ObservationKeys:
            for alias in obs_key.aliases:
                alias_upper = alias.upper()
                if alias_upper in upper_keys:
                    normalized[obs_key.canonical] = upper_keys[alias_upper]
                    break

        return normalized

    @staticmethod
    def _from_fits(path: str | os.PathLike,
                   logger: logging.Logger | None = None,
                   log_level: str = 'INFO',
                   **kwargs) -> Observation:
        '''
        Create an Observation from a FITS file.

        The table stored in the first extension (``hdul[1]``) is converted to
        a dictionary and handed over to ``_from_mapping``.

        Parameters
        ----------
        path : str | os.PathLike
            Path of the data (FITS file). Must end with ".fits"
            (case-insensitive).
        logger : logging.Logger
            Logger. A new one is created when omitted.
        log_level : str
            Level used when a new logger has to be created.
        **kwargs :
            Additional arguments forwarded to ``_from_mapping``.

        Returns
        -------
        Observation
            Instance of class Observation

        Raises
        ------
        ForMoSAError
            If the path does not end with ".fits" or the file holds fewer
            than two HDUs.

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(log_level, name='Observation loader')

        if not str(path).lower().endswith(".fits"):
            raise ForMoSAError(f'{path} is not a FITS file')

        with fits.open(path) as hdul:
            if len(hdul) < 2:
                raise ForMoSAError(f'{path} does not contain a data extension')

            logger.info(f' Loading Observation from FITS file: {path}')
            data = misc.from_recarray_to_dic(hdul[1].data)

            # Build the observation while the file is still open, so the
            # table data is consumed before the HDU list is closed.
            return ObservationLoader._from_mapping(data=data, logger=logger, **kwargs)

    @staticmethod
    def _from_data(data: Mapping[str, np.ndarray],
                   logger: logging.Logger | None = None,
                   log_level: str = 'INFO',
                   **kwargs) -> Observation:
        '''
        Create an Observation from an in-memory dictionary.

        Parameters
        ----------
        data : Mapping[str, np.ndarray]
            Data
        logger : logging.Logger
            Logger. A new one is created when omitted.
        log_level : str
            Level used when a new logger has to be created.
        **kwargs :
            Additional arguments forwarded to ``_from_mapping``.

        Returns
        -------
        Observation
            Instance of class Observation

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(log_level, name='Observation loader')
        logger.info(' Creating Observation from data')
        return ObservationLoader._from_mapping(data=data, logger=logger, **kwargs)

    @staticmethod
    def _from_attributes(**kwargs) -> Observation:
        '''
        Create an Observation from a list of keyword attributes.

        ``logger`` and ``log_level`` are consumed here; every remaining
        non-None keyword becomes a data column. The remaining keywords are
        also forwarded unchanged to ``_from_mapping``.

        Parameters
        ----------
        **kwargs :
            Keyword attributes

        Returns
        -------
        Observation

        Notes
        -----
        Authors: Allan Denis
        '''
        # Retrieve logger arguments
        logger = kwargs.pop('logger', None)
        log_level = kwargs.pop('log_level', 'INFO')
        logger = logger or setup_logging(log_level, name='Observation loader')
        logger.info(' Loading observation from attributes')

        # Dictionary representation of the data
        data = ObservationLoader._attributes_to_dict(**kwargs)

        return ObservationLoader._from_mapping(data=data, logger=logger, **kwargs)

    @staticmethod
    def _label_column(data: Mapping[str, np.ndarray],
                      normalized: Mapping[str, str],
                      canonical: str,
                      name: str,
                      size: int,
                      logger: logging.Logger) -> np.ndarray:
        '''
        Return a string column, or an array of 'unknown' when it is absent.

        Parameters
        ----------
        data : Mapping[str, np.ndarray]
            Data
        normalized : Mapping[str, str]
            Output of ``_normalize_keys`` for ``data``.
        canonical : str
            Canonical key tested for presence in ``normalized``.
        name : str
            Name used for the lookup in ``normalized`` and in the warning.
        size : int
            Length of the placeholder array.
        logger : logging.Logger
            Logger receiving the warning when the column is missing.

        Returns
        -------
        np.ndarray
            Column converted to ``str`` dtype, or ``size`` times 'unknown'.
        '''
        if canonical not in normalized:
            logger.warning(f"Key {name} not in observation keys. Setting it to 'unknown'")
            return np.array(['unknown'] * size)
        return np.asarray(data[normalized[name]], dtype=str)

    @staticmethod
    def _from_mapping(data: Mapping[str, np.ndarray],
                      logger: logging.Logger | None = None,
                      log_level: str = 'INFO',
                      **kwargs) -> Observation:
        '''
        Core logic shared by FITS and in-memory creation.

        The observation is treated as spectroscopic when a resolution column
        is present and holds at least one value > 0; otherwise it is treated
        as photometric.

        Parameters
        ----------
        data : Mapping[str, np.ndarray]
            Data
        logger : logging.Logger
            Logger. A new one is created when omitted.
        log_level : str
            Level used when a new logger has to be created.
        **kwargs :
            Additional arguments. ``wave_cont`` and ``res_cont`` are read
            from here when the corresponding columns are absent from ``data``.

        Returns
        -------
        Observation
            SpectralObservation or PhotometryObservation instance.

        Raises
        ------
        ForMoSAError
            If ``data`` is not a mapping, if required keys are missing, or if
            a spectroscopic observation provides neither an error nor a
            covariance.

        Notes
        -----
        Authors: Allan Denis
        '''
        logger = logger or setup_logging(log_level, name='Observation loader')

        if not isinstance(data, Mapping):
            raise ForMoSAError('data must be a mapping (dict-like)', logger)

        # --------------------------------
        # Retrieve and normalize keys
        # --------------------------------
        normalized = ObservationLoader._normalize_keys(data.keys())

        def optional(obs_key, name):
            # Column stored under `name`, or None when the key is not provided
            return data[normalized[name]] if obs_key.canonical in normalized else None

        # --------------------------------
        # Check required common keys
        # --------------------------------
        missing = [key.canonical
                   for key in ObservationKeys.required_common()
                   if key.canonical not in normalized]
        if missing:
            raise ForMoSAError(f" Missing required observation keys: {', '.join(missing)}", logger)

        # --------------------------------
        # Extract common data
        # --------------------------------
        wave = data[normalized["WAVELENGTH"]]
        flux = data[normalized["FLUX"]]

        # --------------------------------
        # Specific case for units
        # --------------------------------
        if ObservationKeys.WAVELENGTH_UNIT.canonical not in normalized:
            native_unit = WavelengthUnit.MICROMETER
            logger.warning(f'Wavelength unit not specified for observation. Assuming {native_unit.unit}')
        else:
            # The unit column is reduced to its first unique value, which is
            # then looked up by name in WavelengthUnit.
            unit_value = np.unique(np.asarray(data[normalized["WAVELENGTH_UNIT"]], dtype=str))[0]
            native_unit = WavelengthUnit[str(unit_value)]

        # --------------------------------
        # Detect spectroscopic observation
        # --------------------------------
        is_spectro = (ObservationKeys.RESOLUTION.canonical in normalized
                      and np.any(np.array(data[normalized["RESOLUTION"]]) > 0))

        # =============================
        # Spectroscopic observation
        # =============================
        if is_spectro:
            missing_spectro = ObservationKeys.validate_spectroscopic(set(normalized.keys()))
            if missing_spectro:
                raise ForMoSAError(f"Missing required observation keys: {', '.join(missing_spectro)}", logger)

            # Facility and instrument (optional here, default to 'unknown')
            facility = ObservationLoader._label_column(
                data, normalized, ObservationKeys.FACILITY.canonical, 'FACILITY', len(wave), logger)
            ins = ObservationLoader._label_column(
                data, normalized, ObservationKeys.INSTRUMENT.canonical, 'INSTRUMENT', len(wave), logger)
            logger.info(f' Detected spectroscopic observation with instruments {np.unique(facility)}/{np.unique(ins)}')

            # Resolution
            res = data[normalized["RESOLUTION"]]

            # ----------------------------
            # Optional inputs
            # ----------------------------
            # Error / Covariance
            err = optional(ObservationKeys.ERROR, "ERROR")
            cov = optional(ObservationKeys.COVARIANCE, "COVARIANCE")
            if err is None:
                if cov is None:
                    # Fail with an explicit message rather than the opaque
                    # numpy error np.diag(None) would produce.
                    raise ForMoSAError("Missing required observation keys: ERROR or COVARIANCE", logger)
                # Derive the errors from the diagonal of the covariance
                err = np.sqrt(np.diag(cov))

            # Transmission
            transm = optional(ObservationKeys.TRANSMISSION, "TRANSMISSION")

            # Stellar flux and systematics (PREFIX1, PREFIX2, ... columns)
            star_flux = ObservationLoader._extract_vector_series(data, ObservationKeys.STAR_FLUX)
            system = ObservationLoader._extract_vector_series(data, ObservationKeys.SYSTEMATICS)

            # Continuums
            flux_cont = optional(ObservationKeys.FLUX_CONT, 'FLUX_CONT')
            star_flux_cont = optional(ObservationKeys.STAR_FLUX_CONT, 'STAR_FLUX_CONT')

            # Wavelength and resolution for the continuum: taken from the
            # data when present, otherwise from the keyword arguments.
            # NOTE(review): a TypeError during conversion (e.g. float() on a
            # multi-element array) silently yields None - confirm this is
            # the intended behaviour for table columns.
            try:
                wave_cont = (str(data[normalized['WAVE_CONT']])
                             if ObservationKeys.WAVE_CONT.canonical in normalized
                             else kwargs.get('wave_cont', None))
            except TypeError:
                wave_cont = None
            try:
                res_cont = (float(data[normalized['RES_CONT']])
                            if ObservationKeys.RES_CONT.canonical in normalized
                            else kwargs.get('res_cont', None))
            except TypeError:
                res_cont = None

            obs = SpectralObservation(
                wave=wave,
                flux=flux,
                err=err,
                res=res,
                native_unit=native_unit,
                facility=facility,
                instrument=ins,
                cov=cov,
                transm=transm,
                star_flux=star_flux,
                system=system,
            )
            # Continuum information is attached after construction
            obs._flux_cont = flux_cont
            obs._star_flux_cont = star_flux_cont
            obs._res_cont = res_cont
            obs._wave_cont = wave_cont

        # =============================
        # Photometric observation
        # =============================
        else:
            missing_photo = ObservationKeys.validate_photometric(set(normalized.keys()))
            if missing_photo:
                raise ForMoSAError(f"Missing required observation keys: {', '.join(missing_photo)}", logger)

            # Facility, ins and filter_id
            facility = np.asarray(data[normalized["FACILITY"]], dtype=str)
            ins = np.asarray(data[normalized["INSTRUMENT"]], dtype=str)
            filter_id = np.asarray(data[normalized["FILTER_ID"]], dtype=str)
            logger.info(f' Detected photometric observation with filter {np.unique(filter_id)}')

            # Error
            err = data[normalized["ERROR"]]

            obs = PhotometryObservation(
                wave=wave,
                flux=flux,
                err=err,
                native_unit=native_unit,
                facility=facility,
                instrument=ins,
                filter_id=filter_id,
            )

        return obs

    @staticmethod
    def _extract_vector_series(data: Mapping[str, np.ndarray],
                               key: ObservationKeys,
                               **kwargs) -> None | np.ndarray:
        '''
        Extract columns like PREFIX1, PREFIX2, ... and stack them.

        A column is selected when its upper-cased name starts with one of the
        aliases of ``key`` and does not end with '_CONT' (so that, e.g.,
        STAR_FLUX_CONT is not mistaken for a STAR_FLUX column). Selected
        columns are ordered by their trailing integer, so PREFIX2 comes
        before PREFIX10.

        Parameters
        ----------
        data : Mapping[str, np.ndarray]
            Data
        key : ObservationKeys
            Key of the column we want to extract
            (ObservationKeys.STAR_FLUX, ObservationKeys.SYSTEMATICS)
        **kwargs :
            Additional arguments (accepted for compatibility, unused).

        Returns
        -------
        None | np.ndarray
            None when no column matches; otherwise a 2D array whose first
            axis has the length of each column and whose second axis stacks
            the matched columns.

        Examples
        --------
        >>> systematics_array = ObservationLoader._extract_vector_series(data, ObservationKeys.SYSTEMATICS)
        >>> # columns ordered as [systematics1, systematics2, systematics3, ...]

        Notes
        -----
        Authors: Allan Denis
        '''
        # Get the aliases corresponding to the key
        aliases = tuple(alias.upper() for alias in key.aliases)

        def series_order(name):
            # Split "PREFIX12" into ("PREFIX", 12) so the numeric part is
            # compared as an integer rather than character by character.
            upper = name.upper()
            stem = upper.rstrip('0123456789')
            digits = upper[len(stem):]
            return (stem, int(digits) if digits else -1, name)

        # Careful with STAR_FLUX_CONT keywords
        matched_keys = sorted(
            (k for k in data.keys()
             if k.upper().startswith(aliases) and not k.upper().endswith('_CONT')),
            key=series_order,
        )

        if not matched_keys:
            return None

        # Each column becomes (n_rows, n_values) before stacking along axis 1
        vectors = [np.atleast_2d(data[k]).reshape(len(data[k]), -1)
                   for k in matched_keys]

        return np.concatenate(vectors, axis=1)