import copy
import corner
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import matplotlib.gridspec as gridspec
from matplotlib.axes._axes import Axes
from ForMoSA.core.config import PLOTS_CONFIG
from ForMoSA.core.errors import ForMoSAError
from ForMoSA.core.loggings import setup_logging
from ForMoSA.transform.observed import ObservedModel
from ForMoSA.nested_sampling.results import NSResults
from ForMoSA.observation.observation_set import ObservationSet
class Plotting(object):
    '''
    Visualisation of the results of the nested sampling.

    Parameters
    ----------
    results : NSResults
        Instance of class NSResults
    logger : Logger | None
        Logger used. If None (or falsy), a logger is created with ``setup_logging(log_level)``
    log_level : str
        Level of the logger created when ``logger`` is not given

    Raises
    ------
    ForMoSAError
        If ``results`` is not an instance of NSResults.

    Notes
    -----
    Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
    '''

    def __init__(self, results: NSResults, logger: logging.Logger | None = None, log_level: str = 'INFO') -> None:
        self._logger = logger or setup_logging(log_level)
        # Validate before storing so that a failed construction never holds a wrong-typed results object
        if not isinstance(results, NSResults):
            raise ForMoSAError(f'<Wrong type for results: {type(results)}. Expected a NSResults>', self.logger)
        self._ns_results = results

    # =================
    # Representation
    # =================

    def __repr__(self):
        return '<Plotting>'

    # =================
    # Properties
    # =================

    @property
    def logger(self) -> logging.Logger:
        """Logger."""
        return self._logger

    @property
    def ns_results(self) -> NSResults:
        """Instance of class NSResults."""
        return self._ns_results

    # =================
    # Methods
    # =================

    def plot_corner(self) -> Figure:
        '''
        Corner plot of the weighted posterior samples (burn-in discarded).

        The corner options are read from ``PLOTS_CONFIG.CornerPlot``. The labels, the weights
        and the per-parameter ``range`` are set here.

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing the corner plot.

        Notes
        -----
        Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self._logger.info(' Plotting Corner plot')
        burn_in = self.ns_results.burn_in
        samples = self.ns_results.samples[burn_in:]
        weights = self.ns_results.weights[burn_in:]

        # Get config for Corner plot
        config = PLOTS_CONFIG.CornerPlot

        # Work on a copy so that the per-call entries set below (labels, weights, range)
        # never leak into the shared plot configuration
        corner_kwargs = dict(config.to_dict)
        corner_kwargs['labels'] = self.ns_results.free_parameters
        corner_kwargs['weights'] = weights
        # A float per parameter: fraction of the samples included in each axis range
        corner_kwargs['range'] = [0.99999] * len(self.ns_results.free_parameters)

        return corner.corner(samples, **corner_kwargs)

    def plot_chains(self) -> tuple[Figure, np.ndarray]:
        '''
        Plot the trace of every free parameter along the full chain of samples.

        There is one panel per parameter, laid out in two columns. Each panel shows:
        - the samples;
        - a vertical line at the burn-in index;
        - optionally, depending on ``PLOTS_CONFIG.ChainsPlot``, the sample weights
          on a twin axis;
        - optionally, a horizontal line at the median value of the parameter.

        Returns
        -------
        tuple[matplotlib.figure.Figure, numpy.ndarray]
            Figure and flat array of the Axes used (one per free parameter)

        Notes
        -----
        Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self._logger.info(' Plotting posterior chains for each parameter.')
        samples = self.ns_results.samples
        weights = self.ns_results.weights
        # NOTE(review): assumes median_parameters is ordered like free_parameters — confirm
        param_best_values = list(self.ns_results.median_parameters.values())
        n_params = samples.shape[1]
        n_rows = (n_params + 1) // 2

        # Get config for chains plot
        config = PLOTS_CONFIG.ChainsPlot

        fig, axs = plt.subplots(n_rows, 2, figsize=config.figsize)
        axs = axs.flatten()

        for idx in range(n_params):
            ax = axs[idx]
            ax.plot(samples[:, idx], color=config.color_chains, alpha=config.alpha_chains)
            ax.set_ylabel(self.ns_results.free_parameters[idx])

            # Burn-in marker
            ax.axvline(self.ns_results.burn_in, linestyle=config.linestyle_burn_in, color=config.color_plot_burn_in)
            ax.text(x=config.text_burn_in[0], y=config.text_burn_in[1], s='burn in',
                    color=config.color_text_burn_in, transform=ax.transAxes, fontsize=config.fontsize_burn_in)

            if config.show_weights:
                # Weights on a twin y-axis without ticks (only their shape matters)
                ax_w = ax.twinx()
                ax_w.plot(weights, color=config.color_plot_weights, alpha=config.alpha_weights)
                ax_w.set_yticks([])
                ax_w.text(x=config.text_weights[0], y=config.text_weights[1], s='weights',
                          color=config.color_text_weights, transform=ax_w.transAxes, fontsize=config.fontsize_weights)

            if config.plot_best_value:
                ax.axhline(param_best_values[idx], color=config.color_best_value, linestyle=config.linestyle_best_value)

        # Remove the unused panel when the number of parameters is odd
        for ax in axs[n_params:]:
            fig.delaxes(ax)

        return fig, axs[:n_params]

    def plot_radars(self) -> tuple[Figure, Axes]:
        '''
        Radar plot of the posterior quantiles of each free parameter.

        Each spoke corresponds to one free parameter. The shaded area spans the weighted
        quantiles ``PLOTS_CONFIG.RadarPlot.quantiles`` and the line joins the weighted medians.
        These are computed on the post-burn-in samples. Values are normalised by the min/max
        of the full chain (burn-in included), which is used as a proxy for the prior bounds.

        Returns
        -------
        tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes]
            Tuple containing the Figure and the (polar) Axes objects

        Notes
        -----
        Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self._logger.info(' Plotting radar plot of the chains')
        burn_in = self.ns_results.burn_in
        all_samples = self.ns_results.samples
        # Samples and weights must have the same length for the weighted quantiles
        samples = all_samples[burn_in:]
        weights = self.ns_results.weights[burn_in:]
        n_params = samples.shape[1]

        # Get config for radar plot
        config = PLOTS_CONFIG.RadarPlot

        def _quantiles(q: float) -> np.ndarray:
            '''Weighted quantile ``q`` of every parameter, as a flat array.'''
            return np.ravel([self.ns_results._weighted_quantile(samples[:, i], weights, q)
                             for i in range(n_params)])

        q_low = _quantiles(config.quantiles[0])
        q_med = _quantiles(0.5)
        q_high = _quantiles(config.quantiles[1])

        # Use min/max of the full chain to simulate prior bounds (constant parameters get a unit span)
        prior_mins = np.min(all_samples, axis=0)
        prior_maxs = np.max(all_samples, axis=0)
        spans = np.where(prior_maxs != prior_mins, prior_maxs - prior_mins, 1.0)

        # Normalise on the "prior-like" range and close the polygon by repeating the first point
        def _close(values: np.ndarray) -> np.ndarray:
            normalised = (values - prior_mins) / spans
            return np.append(normalised, normalised[0])

        q_low_norm = _close(q_low)
        q_med_norm = _close(q_med)
        q_high_norm = _close(q_high)

        # Angles for the radar plot (closed as well)
        angles = np.linspace(0, 2 * np.pi, len(self.ns_results.free_parameters), endpoint=False)
        angles = np.append(angles, angles[0])

        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
        ax.fill_between(angles, q_low_norm, q_high_norm, color=config.color_radar, alpha=config.alpha_fill)
        ax.plot(angles, q_med_norm, color=config.color_radar, linewidth=2)
        ax.scatter(angles[:-1], q_med_norm[:-1], color=config.color_quantiles, s=config.size_quantiles)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(self.ns_results.free_parameters, fontsize=config.fontsize_names)
        ax.set_yticklabels([])
        ax.grid(True)

        # Display the three inner ticks (out of five evenly spaced values) along each spoke
        for angle, lo, hi, span in zip(angles[:-1], prior_mins, prior_maxs, spans):
            for tick in np.linspace(lo, hi, num=5)[1:-1]:
                # NOTE(review): config attribute is spelled 'fontisze_ticks'; kept as-is to match the config
                ax.text(angle, (tick - lo) / span, f'{tick:.2f}', ha='center', va='center',
                        fontsize=config.fontisze_ticks, color=config.color_ticks)

        return fig, ax

    def plot_fit(self, observations: ObservationSet, best_fit: list[ObservedModel], figsize: tuple = (12, 7),
                 plot_native_model: bool = False,
                 native_model: ObservedModel | None = None) -> tuple[Figure, Axes, Axes | None, Axes, Axes]:
        '''
        Plot the observations together with the best-fit model and the normalised residuals.

        The plotted observations are copies from which each model's ``component`` has been
        subtracted. Residuals are divided by the standard deviation of the residuals of all
        observations combined.

        Parameters
        ----------
        observations : ObservationSet
            Instance of class ObservationSet
        best_fit : list[ObservedModel]
            List of instances of class ObservedModel. There must be one per observation, in the
            same order as ``observations.observations``.
        figsize : tuple[float, float]
            Size of the figure
        plot_native_model : bool
            Whether to plot ``native_model``. When True, the per-observation best-fit curves are
            not drawn; the residuals still are.
        native_model : ObservedModel | None
            Instance of ObservedModel. Required when ``plot_native_model`` is True.

        Returns
        -------
        tuple[Figure, Axes, Axes | None, Axes, Axes]
            The figure, the main axis, the photometric-filter axis (None without photometry),
            the residuals axis and the residuals-histogram axis.

        Raises
        ------
        ForMoSAError
            If ``best_fit`` is not a list with one model per observation. Also raised if
            ``plot_native_model`` is set without a valid ``native_model``.

        Notes
        -----
        Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self._logger.info(' Plotting best fit and residuals')

        # Initial checks
        if not isinstance(best_fit, list) or len(best_fit) != observations.n_observations:
            raise ForMoSAError(f'best_fit must be a list with {observations.n_observations} elements (one per observation)', self.logger)
        if plot_native_model and not isinstance(native_model, ObservedModel):
            raise ForMoSAError(f'If you want to plot the native model, native_model must be an instance of ObservedModel. Got {type(native_model)}', self.logger)

        # Get config for best fit
        config = PLOTS_CONFIG.BestFitPlot

        # Copies of the observations with the component estimated by the high-contrast module removed
        obs_set_transformed = ObservationSet(self.logger)
        for obs, model in zip(observations.observations, best_fit):
            obs_transformed = copy.deepcopy(obs)
            obs_transformed._flux -= model.component
            obs_set_transformed.add_observation(obs_transformed)

        fig = plt.figure(figsize=figsize)
        fig.clf()
        gs = gridspec.GridSpec(9, 11)

        # Main axis for observations + best-fit
        ax = fig.add_subplot(gs[2:7, 0:10])
        # Axis for photometric filters (only when there is photometry)
        ax_filt = None
        if observations.has_photometry:
            ax_filt = fig.add_subplot(gs[0:2, 0:10], sharex=ax)
        # Residuals and histogram axes
        axr = fig.add_subplot(gs[7:9, 0:10], sharex=ax)
        axr2 = fig.add_subplot(gs[7:9, 10:11], sharey=axr)

        # Plot native model if required
        if plot_native_model:
            ax.plot(native_model.wave, native_model.flux, color=config.color, linewidth=config.linewidth, zorder=config.zorder)

        # Residuals of every observation, computed once and normalised by their global std
        residuals = [model.residuals(obs.flux) for obs, model in zip(observations.observations, best_fit)]
        global_std = np.std(np.concatenate(residuals))

        # Plot observations
        obs_set_transformed.plot_all(fig=fig, ax=ax, ax_filt=ax_filt)

        # Plot best-fit and residuals
        for obs, model, res in zip(observations.observations, best_fit, residuals):
            res_norm = res / global_std
            if obs.is_photometric:
                if not plot_native_model:
                    ax.scatter(model.wave, model.flux, marker='o', c=config.color, zorder=config.zorder)
                axr.scatter(obs.wave, res_norm, c=config.color, marker='o')
            else:
                if not plot_native_model:
                    ax.plot(model.wave, model.flux, color=config.color, linewidth=config.linewidth, zorder=config.zorder)  # Best-fit
                axr.plot(obs.wave, res_norm, c=config.color)  # Residuals
            axr2.hist(res_norm, orientation='horizontal', bins=100, color=config.color, alpha=0.8, density=True)

        axr.set_xlabel(r'Wavelength ($\mu$m)')
        axr.set_ylabel(r'Residuals ($\sigma$)')
        axr.axhline(y=0, linestyle='--', color='lightgrey')
        axr2.axis('off')

        return fig, ax, ax_filt, axr, axr2

    def plot_ccf(self, rv_grid: np.ndarray, ccf: np.ndarray, acf: np.ndarray,
                 ccf_star: np.ndarray | None = None, title: str | None = None) -> tuple[Figure, Axes]:
        '''
        Plot the Cross-Correlation Function (CCF).

        The auto-correlation function (ACF) is drawn shifted to the RV of the CCF peak.

        Parameters
        ----------
        rv_grid : np.ndarray
            Grid of radial velocity values (in km/s)
        ccf : np.ndarray
            Corresponding CCF (cross-correlation) values; the peak index is looked up in
            ``rv_grid`` along the first axis (presumably ccf is 1-D — confirm)
        acf : np.ndarray
            ACF (auto-correlation) values, sampled on ``rv_grid``
        ccf_star : np.ndarray | None
            CCF values with star speckles; only drawn when given and not identically zero
        title : str | None
            Optional suffix appended to the plot title

        Returns
        -------
        tuple[Figure, Axes]
            Figure and Axes objects

        Notes
        -----
        Authors: Bhavesh Rajpoot and Allan Denis
        '''
        self._logger.info(' Plotting CCF')

        # Find best RV
        best_idx = np.unravel_index(np.argmax(ccf), ccf.shape)
        rv_peak = rv_grid[best_idx[0]]

        # plot_ccf
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(rv_grid, ccf, label='CCF', color='blue')
        ax.plot(rv_grid + rv_peak, acf, label='ACF', color='orange', linestyle='--')
        if ccf_star is not None and np.any(ccf_star != 0):
            ax.plot(rv_grid, ccf_star, label='Star CCF', color='red', alpha=0.5)
        ax.axvline(rv_peak, color='grey', linestyle=':', label=f'RV = {rv_peak:.1f} km/s')
        ax.set_xlabel('RV (km/s)')
        ax.set_ylabel('CCF (SNR)')
        if title is not None:
            ax.set_title(f'CCF - {title}')
        ax.legend()

        return fig, ax

    def plot_rv_vsini_map(self, rv_grid: np.ndarray, vsini_grid: np.ndarray, logL_map: np.ndarray,
                          title: str | None = None) -> tuple[Figure, Axes]:
        '''
        Plot the RV vs v.sin(i) log-likelihood map and mark its maximum.

        Parameters
        ----------
        rv_grid : np.ndarray
            Grid of radial velocity values (in km/s), i.e. the columns of ``logL_map``
        vsini_grid : np.ndarray
            Grid of v.sin(i) values (in km/s), i.e. the rows of ``logL_map``
        logL_map : np.ndarray
            Log-likelihood map of shape (len(vsini_grid), len(rv_grid))
        title : str | None
            Optional suffix appended to the plot title

        Returns
        -------
        tuple[Figure, Axes]
            Figure and Axes objects

        Notes
        -----
        Authors: Bhavesh Rajpoot (adapted from Allan Denis)
        '''
        self._logger.info(' Computing RV-vsini map')

        # Find best RV and vsini (rows -> vsini, columns -> RV)
        best_idx = np.unravel_index(np.argmax(logL_map), logL_map.shape)
        best_vsini = vsini_grid[best_idx[0]]
        best_rv = rv_grid[best_idx[1]]

        # The grids hold pixel centres while imshow's extent expects outer pixel edges;
        # padding by half a step keeps the best-fit marker centred on its pixel
        rv_lo, rv_hi = self._pixel_edges(rv_grid)
        vsini_lo, vsini_hi = self._pixel_edges(vsini_grid)
        extent = [rv_lo, rv_hi, vsini_lo, vsini_hi]

        # plot rv/vsini map
        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.imshow(logL_map, aspect='auto', origin='lower', extent=extent, cmap='viridis')
        ax.scatter(best_rv, best_vsini, marker='x', color='red', s=100, label=f'Best: RV={best_rv:.1f}, vsini={best_vsini:.1f}')
        ax.set_xlabel('RV (km/s)')
        ax.set_ylabel(r'v.sin(i) (km/s)')
        if title is not None:
            ax.set_title(f'RV - v.sin(i) map - {title}')
        ax.legend()
        fig.colorbar(im, ax=ax, label='log L')

        return fig, ax

    # =================
    # Helpers
    # =================

    @staticmethod
    def _pixel_edges(grid: np.ndarray) -> tuple[float, float]:
        '''
        Return the outer pixel edges of a uniformly sampled grid of pixel centres.

        The result is suitable for ``imshow(extent=...)``. A single-point grid is given a
        pixel width of 1 (in grid units).
        '''
        grid = np.asarray(grid, dtype=float)
        if grid.size > 1:
            half_step = 0.5 * (grid[-1] - grid[0]) / (grid.size - 1)
        else:
            half_step = 0.5
        return grid[0] - half_step, grid[-1] + half_step