import copy
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import matplotlib.gridspec as gridspec
from matplotlib.axes._axes import Axes
from ForMoSA.core.errors import ForMoSAError
from ForMoSA.filter.filter import PhotometryFilter
from ForMoSA.observation.observation_base import Observation
from ForMoSA.core.config import PhotometricPlotConfig, MAIN_PLOT
from ForMoSA.core.enums import ObservationType, ObservationKeys, WavelengthUnit
class PhotometryObservation(Observation):
    '''
    Photometric observation class.

    Each data point ``i`` (``wave[i]``, ``flux[i]``, ``err[i]``) is measured through
    one photometric filter, identified by ``facility[i]``, ``instrument[i]`` and
    ``filter_id[i]``.

    Parameters
    ----------
    wave : np.ndarray
        Wavelength array (one value per data point)
    flux : np.ndarray
        Flux array
    err : np.ndarray
        Error array
    instrument : np.ndarray
        Instrument of each data point
    facility : np.ndarray
        Facility of each data point
    filter_id : np.ndarray
        Filter ID of each data point
    native_unit : WavelengthUnit
        Native unit of the wavelength
    logger : logging.Logger, optional
        Logger
    log_level : str
        Level of the logger
    display_unit : WavelengthUnit
        Unit of the wavelength to display

    Raises
    ------
    ForMoSAError
        If filter_id, instrument, facility and wave do not all have the same
        length, if there is no data point, or if a wavelength lies outside the
        wavelength range of its own filter.

    Notes
    -----
    Authors: Allan Denis
    '''

    def __init__(
        self,
        wave: np.ndarray,
        flux: np.ndarray,
        err: np.ndarray,
        instrument: np.ndarray,
        facility: np.ndarray,
        filter_id: np.ndarray,
        native_unit: WavelengthUnit,
        logger: logging.Logger | None = None,
        log_level: str = 'INFO',
        display_unit: WavelengthUnit = WavelengthUnit.MICROMETER,
    ) -> None:
        # Normalise filter IDs to a 1-D string array. This is stored before the parent
        # constructor runs. NOTE(review): presumably because Observation.__init__ may
        # already use filter-dependent properties such as `name` — confirm.
        self._filter_id = np.atleast_1d(np.asarray(filter_id, dtype=str))
        # Inherit from Observation class
        super().__init__(wave=wave, flux=flux, err=err, facility=facility, instrument=instrument,
                         native_unit=native_unit, logger=logger, log_level=log_level,
                         display_unit=display_unit, plot_config=PhotometricPlotConfig())
        # One PhotometryFilter per data point, filled by _validate_photometry()
        self._Filter = np.array([])
        self._validate_photometry()

    # ==================================================
    # Representation
    # ==================================================
    def __repr__(self) -> str:
        return f' PhotometricObservation : {self.name}'

    def __format__(self, format_spec: str = '') -> str:
        # format()/f-strings call __format__(format_spec); accepting the spec (and
        # applying it to the repr string) keeps f'{obs}' from raising TypeError.
        return format(self.__repr__(), format_spec)

    # ==================================================
    # Properties
    # ==================================================
    @property
    def ObsType(self) -> ObservationType:
        """Observation type."""
        return ObservationType.PHOTOMETRIC.obstype

    @property
    def res(self) -> np.ndarray[float]:
        """Resolution (0 for every photometric point)."""
        return np.zeros(len(self.wave))

    @property
    def hc_mode(self) -> bool:
        """Whether observation is in high-contrast mode (never for photometry)."""
        return False

    @property
    def to_dict(self) -> dict[str, list | np.ndarray | str]:
        """
        Dictionary representation of photometric observations.

        Wavelength, flux and error are converted to Python lists; facility,
        instrument and filter ID are returned as stored; the wavelength unit is a
        string.
        """
        return {
            ObservationKeys.WAVELENGTH.canonical: self.wave.tolist(),
            ObservationKeys.FLUX.canonical: self.flux.tolist(),
            ObservationKeys.ERROR.canonical: self.err.tolist(),
            ObservationKeys.FACILITY.canonical: self.facility,
            ObservationKeys.INSTRUMENT.canonical: self.instrument,
            ObservationKeys.FILTER_ID.canonical: self.filter_id,
            ObservationKeys.WAVELENGTH_UNIT.canonical: str(WavelengthUnit[str(self.unit)].value),
        }

    @property
    def Filter(self) -> np.ndarray[PhotometryFilter]:
        """Photometric filters (one per data point)."""
        return self._Filter

    @property
    def filter_id(self) -> np.ndarray[str]:
        """Filter ID (one per data point)."""
        return self._filter_id

    @property
    def name(self) -> str:
        """Observation name, '[facilities]_[instruments]_[filters]'."""
        # ---- Facilities
        facilities = sorted(set(self.facility.astype(str)))
        facility_str = f'[{"+".join(facilities)}]'
        # ---- Instruments
        instruments = sorted(set(self.instrument.astype(str)))
        instrument_str = f'[{"+".join(instruments)}]'
        # ---- Filters
        filters = sorted(set(self.filter_id.astype(str)))
        nfilters = len(filters)
        # Condense if too many filters
        if nfilters <= 6:
            filter_str = f'[{"+".join(filters)}]'
        else:
            filter_str = f"[{nfilters}filters]"
        return f"{facility_str}_{instrument_str}_{filter_str}"

    @property
    def wavelength_range(self) -> tuple:
        """
        Wavelength range of the observation.

        Smallest ``wavelength_min`` and largest ``wavelength_max`` over all filters.
        Raises ValueError (from numpy) if there is no filter.
        """
        wmin = np.min([filt.wavelength_min for filt in self.Filter])
        wmax = np.max([filt.wavelength_max for filt in self.Filter])
        return wmin, wmax

    @property
    def filter_idxs(self) -> np.ndarray:
        """
        Boundaries of the runs of consecutive identical filter IDs.

        Element ``k`` is the index where the ``k``-th run starts and the last element
        is ``len(filter_id)``, so run ``k`` spans ``filter_idxs[k]:filter_idxs[k + 1]``.
        An empty ``filter_id`` gives ``[0]`` (i.e. zero runs).
        """
        fid = self.filter_id
        if len(fid) == 0:
            return np.array([0])
        # Positions where the filter ID differs from the previous entry
        changes = np.flatnonzero(fid[1:] != fid[:-1]) + 1
        return np.concatenate(([0], changes, [len(fid)]))

    @property
    def nb_filters(self) -> int:
        """Number of runs of consecutive identical filter IDs."""
        return len(self.filter_idxs) - 1

    # ==================================================
    # Methods
    # ==================================================
    def _validate_photometry(self) -> None:
        '''
        Check the photometric observation and build its filters.

        One PhotometryFilter is created per data point from ``facility[i]``,
        ``instrument[i]`` and ``filter_id[i]``, and set to the observation
        wavelength unit. Each wavelength must then lie within the wavelength range
        of its own filter.

        Raises
        ------
        ForMoSAError
            If filter_id, instrument, facility and wave do not all have the same
            length, if there is no data point, or if a wavelength falls outside
            the range of its filter.

        Notes
        -----
        Authors: Allan Denis
        '''
        n_points = len(self.filter_id)
        if n_points != len(self.instrument):
            raise ForMoSAError('filter_id and instrument must have same lengths', self.logger)
        # zip() below would silently truncate on a shorter facility array
        if n_points != len(self.facility):
            raise ForMoSAError('filter_id and facility must have same lengths', self.logger)
        # Filters are paired with data points index by index
        if n_points != len(self.wave):
            raise ForMoSAError('filter_id and wave must have same lengths', self.logger)
        if n_points == 0:
            raise ForMoSAError('Photometric observation must contain at least one data point', self.logger)

        unit = WavelengthUnit[str(self.unit)]
        # Pre-allocated object array: avoids the quadratic np.append pattern
        filters = np.empty(n_points, dtype=object)
        for i, (facility, instrument, filt_id) in enumerate(zip(self.facility, self.instrument, self.filter_id)):
            filt = PhotometryFilter(facility, instrument, filt_id)
            filt._set_unit(unit)
            filters[i] = filt
        self._Filter = filters

        # Every point (not only the first one) must fall inside its own filter
        for wave, filt in zip(self.wave, self._Filter):
            wmin, wmax = filt.wavelength_min, filt.wavelength_max
            if (wave < wmin) or (wave > wmax):
                raise ForMoSAError(f'Wrong value for wave: {wave} (filter {filt.name}). Expected a value between {[wmin, wmax]}', self.logger)

    def _adapt_to_resolution(self, target_resolution: float | None = None, wave_cont: str | None = None, res_cont: float | None = None) -> "PhotometryObservation":
        '''
        For photometry, this function does not implement anything.

        All arguments are ignored; only an info message is logged.

        Returns
        -------
        PhotometryObservation
            The observation itself (unchanged).

        Notes
        -----
        Authors: Allan Denis
        '''
        self.logger.info(f' Observation {self.name} is photometric. No adaptation')
        return self

    def plot_data(self, fig: Figure | None = None, ax: Axes | None = None, ax_filt: Axes | None = None) -> tuple[Figure, Axes, Axes]:
        '''
        Plot photometric data.

        The lower panel (``ax``) shows the photometric points with their error bars
        (the horizontal error bar is the filter width); the upper panel
        (``ax_filt``) shows the transmission curve of each filter.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure (used to overplot on an existing figure). If ``ax`` and
            ``ax_filt`` are given but ``fig`` is not, the figure of ``ax`` is used.
        ax : matplotlib.axes._axes.Axes, optional
            Ax (used to overplot on an existing ax)
        ax_filt : matplotlib.axes._axes.Axes, optional
            Ax used to overplot the transmission filter on an existing ax.
            If either ``ax`` or ``ax_filt`` is None, a new figure with both panels
            is created and any axes passed in are ignored.

        Returns
        -------
        fig : matplotlib.figure.Figure
            Updated figure
        ax : matplotlib.axes._axes.Axes
            Updated ax
        ax_filt : matplotlib.axes._axes.Axes
            Updated ax_filt

        Notes
        -----
        Authors: Allan Denis
        '''
        self.logger.info(f' Plotting data for observation {self.name}')
        plot_config = self.plot_config
        main_plot_config = MAIN_PLOT

        # --------------------------------------------------
        # Figure / axes creation
        # --------------------------------------------------
        if ax is None or ax_filt is None:
            fig = plt.figure(figsize=main_plot_config.figsize)
            gs = gridspec.GridSpec(9, 10)
            ax = fig.add_subplot(gs[3:9, 0:10])
            ax_filt = fig.add_subplot(gs[0:3, 0:10], sharex=ax)
            ax_filt.set_ylabel("Transmission")
        elif fig is None:
            # Axes were supplied without their figure: do not propagate None
            fig = ax.get_figure()

        # --------------------------------------------------
        # Plot data, one run of consecutive identical filter IDs at a time
        # --------------------------------------------------
        # filter_idxs is recomputed on every access, so evaluate it once
        idxs = self.filter_idxs
        filt_handles = []
        filt_labels = []
        for idx0, idx1 in zip(idxs[:-1], idxs[1:]):
            # self.Filter holds one filter per data point; points of a run share
            # the same filter ID, so the run's first filter represents it.
            filt = self.Filter[idx0]
            # ------------------------
            # Transmission (TOP PANEL)
            # ------------------------
            n_lines_before = len(ax_filt.get_lines())
            fig, ax_filt = filt._plot_transmission_curve(fig=fig, ax=ax_filt, plot_config=plot_config, main_plot_config=main_plot_config)
            # Only the lines added for this filter may be paired with its label;
            # taking all of get_lines() would misalign handles and labels.
            # NOTE(review): assumes the first line drawn is the transmission curve.
            new_lines = ax_filt.get_lines()[n_lines_before:]
            if len(new_lines) > 0:
                filt_handles.append(new_lines[0])
                filt_labels.append(f"{filt.name}")
            # ------------------------
            # Photometric points (MID PANEL)
            # ------------------------
            ax.scatter(
                self.wave[idx0:idx1],
                self.flux[idx0:idx1],
                color=plot_config.color,
                edgecolors=plot_config.edgecolor,
                marker=plot_config.marker,
                s=plot_config.markersize,
                linewidths=plot_config.linewidth,
                zorder=plot_config.zorder_data,
                label=f'{filt.name}'
            )
            ax.errorbar(
                self.wave[idx0:idx1],
                self.flux[idx0:idx1],
                yerr=self.err[idx0:idx1],
                xerr=filt.width,
                fmt=plot_config.errorbar_fmt,
                ecolor=plot_config.color,
                alpha=plot_config.errorbar_alpha,
                capsize=plot_config.errorbar_capsize,
                zorder=plot_config.zorder_error
            )

        # --------------------------------------------------
        # Legend upper panel (transmission curves)
        # --------------------------------------------------
        if plot_config.label_filter:
            ax_filt.legend(filt_handles, filt_labels, fontsize=main_plot_config.legend_fontsize, ncol=main_plot_config.legend_filt_ncol, frameon=False)
        # --------------------------------------------------
        # Legend mid panel (Photometric data)
        # --------------------------------------------------
        if plot_config.label_data:
            ax.legend(fontsize=main_plot_config.legend_fontsize, ncol=main_plot_config.legend_ncol, frameon=False)
        # --------------------------------------------------
        # Axis labels (set for both new and user-supplied axes)
        # --------------------------------------------------
        ax.set_xlabel(f"Wavelength ({getattr(self, 'unit', '')})")
        ax.set_ylabel("Flux")
        return fig, ax, ax_filt

    def _restricted_observation(self, windows: str | None = None, print_logger: bool = True) -> "PhotometryObservation":
        '''
        Restrict the observation to wavelength windows.

        A data point is kept if its wavelength lies inside at least one window
        (bounds included). All per-point arrays are restricted consistently on a
        deep copy; the original observation is left untouched.

        Parameters
        ----------
        windows : str, optional
            Windows in the format 'wmin1,wmax1 / wmin2,wmax2 / ...'. Empty
            segments are ignored. If None, a single window spanning the full
            wavelength range is used (no restriction).
        print_logger : bool
            Whether to print logger

        Returns
        -------
        PhotometryObservation
            Restricted observation

        Raises
        ------
        ValueError
            If a window cannot be parsed as two comma-separated floats.

        Notes
        -----
        Authors: Allan Denis
        '''
        restricted = copy.deepcopy(self)
        if windows is None:
            # min/max rather than first/last: photometric points are not
            # necessarily sorted by wavelength.
            windows = f'{np.min(self.wave)}, {np.max(self.wave)}'
        if print_logger:
            self.logger.debug(f'Restricting observation {self.name} onto wavelengths windows {windows}')

        keep = np.zeros(len(self.wave), dtype=bool)
        for window in windows.split("/"):
            if not window.strip():
                # e.g. a trailing '/'
                continue
            wmin, wmax = map(float, window.split(","))
            keep |= (self.wave >= wmin) & (self.wave <= wmax)
        # Sorted, unique indices of the kept points
        ind = np.flatnonzero(keep)

        # Restrict every per-point array, including filter IDs, so that
        # filter_idxs / name / plot_data stay consistent with the kept data.
        for name, value in zip(['_wave', '_flux', '_err', '_Filter', '_filter_id'], [self.wave, self.flux, self.err, self.Filter, self.filter_id]):
            if value is not None:
                setattr(restricted, name, value[ind])
        # Facility and instrument are per-point as well.
        # NOTE(review): assumes the Observation base class stores them as
        # `_facility` / `_instrument` arrays aligned with wave — confirm; they are
        # left untouched if absent or not aligned.
        for name in ('_facility', '_instrument'):
            value = getattr(restricted, name, None)
            if isinstance(value, np.ndarray) and value.ndim == 1 and len(value) == len(self.wave):
                setattr(restricted, name, value[ind])

        if print_logger:
            if ind.size == 0:
                # wavelength_range is undefined for an observation without filters
                self.logger.warning(f' No data point of observation {self.name} lies within the wavelength windows {windows}')
            else:
                self.logger.info(f' Wavelength of former Observation: {self.wavelength_range}. Wavelength of restricted obervation: {restricted.wavelength_range}')
        return restricted