"""Source code for ForMoSA.nested_sampling.plotting: visualisation of the nested-sampling results."""

import corner
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import matplotlib.gridspec as gridspec
from matplotlib.axes._axes import Axes
import matplotlib.patheffects as path_effects
from matplotlib.ticker import AutoMinorLocator
from matplotlib.patches import RegularPolygon
from matplotlib.path import Path as MplPath
from matplotlib.projections import register_projection
from matplotlib.projections.polar import PolarAxes
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D

from ForMoSA.core.config import PLOTS_CONFIG, MAIN_PLOT
from ForMoSA.core.errors import ForMoSAError
from ForMoSA.core.loggings import setup_logging
from ForMoSA.transform.observed import ObservedModel
from ForMoSA.nested_sampling.results import NSResults
from ForMoSA.observation.observation_set import ObservationSet

# Spoke counts whose radar projection has already been registered, so that
# register_projection() is not called again on every plot call.
_REGISTERED_RADARS: set[int] = set()


def _radar_polygon_factory(num_vars: int) -> tuple[np.ndarray, str]:
    """
    Register (once) a polygon-frame radar projection for `num_vars` spokes.

    Parameters
    ----------
    num_vars : int
        Number of spokes; also the number of vertices of the polygonal frame
        and of the polygonal grid rings.

    Returns
    -------
    tuple[np.ndarray, str]
        ``theta``: the `num_vars` spoke angles in radians, evenly spaced on
        [0, 2*pi) (endpoint excluded); and ``name``: the registered projection
        name, to be passed as ``fig.add_subplot(projection=name)``.

    Notes
    -----
    Authors: Bhavesh Rajpoot (Adapted from matplotlib's radar chart gallery example:
    https://matplotlib.org/stable/gallery/specialty_plots/radar_chart.html)
    """
    name = f'_formosa_radar_{num_vars}'
    theta = np.linspace(0.0, 2.0 * np.pi, num_vars, endpoint=False)

    if num_vars in _REGISTERED_RADARS:
        return theta, name

    class RadarTransform(PolarAxes.PolarTransform):
        def transform_path_non_affine(self, path):
            # Following the matplotlib gallery example: paths with more than one
            # interpolation step (gridlines) are resampled with `num_vars` points so
            # they are drawn as polygons instead of circular arcs.
            if path._interpolation_steps > 1:
                path = path.interpolated(num_vars)
            return MplPath(self.transform(path.vertices), path.codes)

    class RadarAxes(PolarAxes):
        PolarTransform = RadarTransform

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_theta_zero_location('N')   # first axis at top

        def fill(self, *args, closed=True, **kwargs):
            """Same as PolarAxes.fill, but polygons are closed by default."""
            return super().fill(*args, closed=closed, **kwargs)

        def plot(self, *args, **kwargs):
            """Same as PolarAxes.plot, but each line is closed back to its first point."""
            lines = super().plot(*args, **kwargs)
            for line in lines:
                x, y = line.get_data()
                # Guard against empty lines before indexing the first/last point
                if len(x) and x[0] != x[-1]:
                    line.set_data(np.append(x, x[0]), np.append(y, y[0]))
            # Return the lines like Axes.plot does, so callers can style them
            return lines

        def _gen_axes_patch(self):
            # Polygonal (instead of circular) background patch
            return RegularPolygon((0.5, 0.5), num_vars, radius=0.5, edgecolor='none')

        def _gen_axes_spines(self):
            # Polygonal outer frame in axes coordinates (unit polygon scaled to the axes box)
            spine = Spine(axes=self, spine_type='circle',
                          path=MplPath.unit_regular_polygon(num_vars))
            spine.set_transform(Affine2D().scale(0.5).translate(0.5, 0.5) + self.transAxes)
            return {'polar': spine}

    RadarAxes.name = name
    register_projection(RadarAxes)
    _REGISTERED_RADARS.add(num_vars)
    return theta, name

class Plotting:
    '''
    Class of visualisation of the results of the nested sampling.

    Parameters
    ----------
    results : NSResults
        Instance of class NSResults
    logger : Logger | None
        Logger used. If None, a logger is created with
        ``setup_logging(log_level, name='Plotting')``.
    log_level : str
        Level of the Logger (only used when ``logger`` is None)

    Raises
    ------
    ForMoSAError
        If ``results`` is not an instance of NSResults.

    Notes
    -----
    Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
    '''

    def __init__(self, results: NSResults, logger: logging.Logger | None = None,
                 log_level: str = 'INFO') -> None:
        self._logger = logger if logger is not None else setup_logging(log_level, name='Plotting')
        if not isinstance(results, NSResults):
            raise ForMoSAError(f'<Wrong type for results: {type(results)}. Expected a NSResults>', self.logger)
        self._ns_results = results

    # =================
    # Representation
    # =================
    def __repr__(self):
        return '<Plotting>'

    # =================
    # Properties
    # =================
    @property
    def logger(self) -> logging.Logger:
        """Logger."""
        return self._logger

    @property
    def ns_results(self) -> NSResults:
        """Instance of class NSResults."""
        return self._ns_results

    # =================
    # Methods
    # =================
    def plot_corner(self) -> Figure:
        '''
        Corner plot of the posterior samples from the nested sampling results.

        Samples and weights before ``ns_results.burn_in`` are discarded. The
        ``corner.corner`` keyword arguments come from ``PLOTS_CONFIG.CornerPlot``.

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing the corner plot.

        Notes
        -----
        Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self._logger.info(' Plotting Corner plot')

        burn_in = self.ns_results.burn_in
        samples = self.ns_results.samples[burn_in:]
        weights = self.ns_results.weights[burn_in:]

        # Get config for Corner plot
        config = PLOTS_CONFIG.CornerPlot

        # Copy the config arguments so the per-call entries below never leak
        # back into a dict that may be shared by the config object.
        corner_kwargs = dict(config.to_dict)
        corner_kwargs['labels'] = self.ns_results.free_parameters
        corner_kwargs['weights'] = weights
        # A float in `range` is the fraction of samples corner keeps inside the bounds
        corner_kwargs['range'] = [0.99999] * len(self.ns_results.free_parameters)

        return corner.corner(samples, **corner_kwargs)

    def plot_chains(self) -> tuple[Figure, np.ndarray]:
        '''
        Plot the chains of the samples results, one panel per free parameter.

        Panels are laid out on two columns; each one shows the full chain
        (burn-in included), a vertical line at ``ns_results.burn_in`` and,
        depending on ``PLOTS_CONFIG.ChainsPlot``, the weights (on a twin y-axis)
        and a horizontal line at the median value of the parameter.

        Returns
        -------
        tuple[matplotlib.figure.Figure, numpy.ndarray]
            Figure and the array of the Axes used (one per parameter; unused
            panels are removed from the figure).

        Notes
        -----
        Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self._logger.info(' Plotting posterior chains for each parameter.')

        samples = self.ns_results.samples
        weights = self.ns_results.weights
        # NOTE(review): assumes median_parameters is ordered like free_parameters
        param_best_values = list(self.ns_results.median_parameters.values())
        n_params = samples.shape[1]
        n_rows = (n_params + 1) // 2

        # Get config for chains plot
        config = PLOTS_CONFIG.ChainsPlot

        fig, axs = plt.subplots(n_rows, 2, figsize=config.figsize)
        axs = axs.flatten()

        for idx, ax in enumerate(axs[:n_params]):
            param_name = self.ns_results.free_parameters[idx]
            ax.plot(samples[:, idx], color=config.color_chains, alpha=config.alpha_chains)
            ax.set_ylabel(param_name)
            ax.axvline(self.ns_results.burn_in, linestyle=config.linestyle_burn_in,
                       color=config.color_plot_burn_in)
            ax.text(x=config.text_burn_in[0], y=config.text_burn_in[1], s='burn in',
                    color=config.color_text_burn_in, transform=ax.transAxes,
                    fontsize=config.fontsize_burn_in)

            if config.show_weights:
                ax_w = ax.twinx()
                # Pass the colour as a keyword: a positional argument would be parsed
                # as a format string (or as y data for a non-string colour).
                ax_w.plot(weights, color=config.color_plot_weights, alpha=config.alpha_weights)
                ax_w.set_yticks([])
                ax_w.text(x=config.text_weights[0], y=config.text_weights[1], s='weights',
                          color=config.color_text_weights, transform=ax_w.transAxes,
                          fontsize=config.fontsize_weights)

            if config.plot_best_value:
                ax.axhline(param_best_values[idx], color=config.color_best_value,
                           linestyle=config.linestyle_best_value)

        # Remove the spare panel when the number of parameters is odd
        for idx in range(n_params, len(axs)):
            fig.delaxes(axs[idx])

        return fig, axs[:n_params]

    def plot_radars(self) -> tuple[Figure, Axes]:
        '''
        Plot a radar (spider) chart summarising the posterior of every free parameter.

        Samples before ``ns_results.burn_in`` are discarded. Each spoke shows one
        parameter: the weighted median (polygon and markers) and the band between
        the two weighted quantiles ``PLOTS_CONFIG.RadarPlot.quantiles``. Every spoke
        has its own value range, centred on the median and extending ``K_MARGIN``
        times the quantile width on each side, snapped to a "nice" step.

        Returns
        -------
        tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes]
            Figure and radar Axes.

        Notes
        -----
        Authors: Bhavesh Rajpoot (adapted from Paulina Palma-Bifani, Matthieu Ravet, Allan Denis)
        '''
        self._logger.info(' Plotting spider plot of the chains')

        burn_in = self.ns_results.burn_in
        samples = self.ns_results.samples[burn_in:]
        weights = self.ns_results.weights[burn_in:]
        params = self.ns_results.free_parameters
        N = len(params)
        config = PLOTS_CONFIG.RadarPlot

        # Step-size registry for known parameters (looked up lower-cased and stripped)
        _STEPS: dict[str, float] = {
            'rv': 5.0,       # km/s
            'vsini': 10.0,   # km/s
            'd': 1.0,        # pc
            'r': 1.0,        # Rjup
            'teff': 100.0,   # K
            'logg': 0.5,     # dex
            'feh': 0.5,
            '[m/h]': 0.5,
            'co': 0.1,
        }

        def _nice(dr: float) -> float:
            # Round dr / 5 to a 1-2-5-10 multiple of its power of ten
            if dr <= 0.0:
                return 1.0
            raw = dr / 5.0
            mag = 10.0 ** np.floor(np.log10(abs(raw)))
            r = raw / mag
            if r < 1.5:
                return 1.0 * mag
            elif r < 3.5:
                return 2.0 * mag
            elif r < 7.5:
                return 5.0 * mag
            else:
                return 10.0 * mag

        def _step(name: str, dr: float) -> float:
            # Registry step if the parameter is known, otherwise a "nice" step for the range
            return _STEPS.get(name.lower().strip(), _nice(dr))

        def _snap(lo: float, hi: float, s: float) -> tuple[float, float]:
            # Widen [lo, hi] outwards to multiples of s
            return np.floor(lo / s) * s, np.ceil(hi / s) * s

        def _norm(v: float, lo: float, hi: float) -> float:
            return (v - lo) / (hi - lo) if hi != lo else 0.0

        def _fmt(v: float) -> str:
            if abs(v) >= 1000:
                return f'{v:.0f}'
            elif abs(v) >= 10:
                return f'{v:.1f}'
            else:
                return f'{v:.2f}'

        # Weighted quantiles of every parameter
        weighted_quantile = self.ns_results._weighted_quantile
        n_dim = samples.shape[1]
        q_lo = np.asarray([weighted_quantile(samples[:, i], weights, config.quantiles[0]) for i in range(n_dim)])
        q_med = np.asarray([weighted_quantile(samples[:, i], weights, 0.5) for i in range(n_dim)])
        q_hi = np.asarray([weighted_quantile(samples[:, i], weights, config.quantiles[1]) for i in range(n_dim)])

        # Per-axis snapped bounds, based on the quantile width rather than on the full
        # sample range, to avoid large empty space around tight posteriors.
        # With K_MARGIN = 3 the raw range is 6 quantile widths centred on the median:
        # the median sits at ~50% radius and, for a symmetric posterior, the quantile
        # band covers ~42% to ~58% radius before snapping.
        K_MARGIN = 3.0
        ax_lo, ax_hi = [], []
        for i, p in enumerate(params):
            width = q_hi[i] - q_lo[i]
            if width <= 0:
                # Degenerate quantiles: fall back on the sample spread (never exactly 0)
                width = max(float(np.std(samples[:, i])), 1e-6)
            raw_lo = q_med[i] - K_MARGIN * width
            raw_hi = q_med[i] + K_MARGIN * width
            s = _step(p, raw_hi - raw_lo)
            lo, hi = _snap(raw_lo, raw_hi, s)
            ax_lo.append(lo)
            ax_hi.append(hi)

        # Normalize quantiles to [0, 1]; clip prevents excursions outside the ring
        med_n = np.clip([_norm(q_med[i], ax_lo[i], ax_hi[i]) for i in range(N)], 0.0, 1.0)
        lo_n = np.clip([_norm(q_lo[i], ax_lo[i], ax_hi[i]) for i in range(N)], 0.0, 1.0)
        hi_n = np.clip([_norm(q_hi[i], ax_lo[i], ax_hi[i]) for i in range(N)], 0.0, 1.0)

        # Polygon-frame polar projection
        theta, proj = _radar_polygon_factory(N)
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(projection=proj)

        # Grid rings & limits
        N_RINGS = 4
        ring_levels = np.linspace(0.0, 1.0, N_RINGS + 1)[1:]   # [0.25, 0.5, 0.75, 1.0]
        ax.set_rgrids(ring_levels, labels=[''] * len(ring_levels))   # hide default radial labels
        ax.set_ylim(0.0, 1.0)
        ax.grid(True, color='gray', linewidth=0.8, alpha=0.5, linestyle='--')
        ax.spines['polar'].set_color('#808183')
        ax.spines['polar'].set_linewidth(1.2)

        # Uncertainty band. fill_between does not wrap around, so the first spoke is
        # repeated at theta + 2*pi to close the band between the last and first axes.
        theta_closed = np.append(theta, theta[0] + 2.0 * np.pi)
        ax.fill_between(theta_closed, np.append(lo_n, lo_n[0]), np.append(hi_n, hi_n[0]),
                        color=config.color_uncertainty, alpha=config.alpha_fill, zorder=2)

        # Median polygon (closed by the radar projection's plot) + markers with a white halo
        ax.plot(theta, med_n, color=config.color_radar, linewidth=2.0, zorder=3)
        ax.scatter(theta, med_n, color='white', s=config.size_quantiles + 40,
                   zorder=4, edgecolors='none')
        ax.scatter(theta, med_n, color=config.color_quantiles, s=config.size_quantiles,
                   zorder=5, edgecolors='white', linewidths=config.lw_quantiles)

        # Parameter name labels at spoke tips (native matplotlib)
        ax.set_thetagrids(np.degrees(theta), labels=params, fontsize=config.fontsize_names)
        ax.tick_params(axis='x', pad=50, colors='#24292E')

        # Value annotations between the polygon and the parameter name, at a fixed
        # radius along each spoke (independent of the data value) so the labels never
        # crowd the centre or escape the outer ring.
        VAL_R = 1.12
        for k in range(N):
            dlo = q_med[k] - q_lo[k]
            dhi = q_hi[k] - q_med[k]
            lbl = f'${_fmt(q_med[k])}_{{-{_fmt(dlo)}}}^{{+{_fmt(dhi)}}}$'
            ax.text(theta[k], VAL_R, lbl, ha='center', va='center',
                    fontsize=config.fontsize_ticks, fontweight='600',
                    color=config.color_radar, zorder=10, clip_on=False,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='none', alpha=0.85))

        # Optional per-spoke tick labels, placed only at the grid-ring radii (reusing
        # the rings of the visual grid keeps the number of labels small).
        if getattr(config, 'show_ticks', False):
            tick_fontsize = max(config.fontsize_ticks - 2, 11)
            for k in range(N):
                for rl in ring_levels[:-1]:   # skip outermost (overlaps with VAL_R)
                    actual = ax_lo[k] + rl * (ax_hi[k] - ax_lo[k])
                    ax.text(theta[k], rl, _fmt(actual), ha='center', va='center',
                            fontsize=tick_fontsize, color=config.color_ticks,
                            zorder=6, clip_on=False,
                            bbox=dict(boxstyle='round,pad=0.15', facecolor='white', edgecolor='none', alpha=0.5))

        fig.tight_layout()
        return fig, ax

    def plot_fit(self, observations: ObservationSet, best_fit: list[ObservedModel],
                 figsize: tuple[float, float] = (18, 8), plot_native_model: bool = False,
                 native_model: ObservedModel | None = None,
                 ) -> tuple[Figure, Axes, Axes | None, Axes, Axes]:
        '''
        Plot best fit

        Parameters
        ----------
        observations : ObservationSet
            Instance of class ObservationSet
        best_fit : list[ObservedModel]
            List of instances of class ObservedModel corresponding to the best-fit
            model for each observation
        figsize : tuple[float, float]
            Size of the figure
        plot_native_model : bool
            Whether to plot the native model (replaces the per-observation best-fit curves)
        native_model : ObservedModel | None
            Instance of ObservedModel; required when ``plot_native_model`` is True

        Returns
        -------
        tuple[Figure, Axes, Axes | None, Axes, Axes]
            Figure, main axis, filter axis (None without photometry), residuals
            axis and residuals-histogram axis.

        Raises
        ------
        ForMoSAError
            If ``best_fit`` is not a list with one model per observation, or if
            ``plot_native_model`` is set without a valid ``native_model``.

        Notes
        -----
        Authors: Paulina Palma-Bifani, Matthieu Ravet and Allan Denis
        '''
        self._logger.info(' Plotting best fit and residuals')

        # Initial checks
        if not isinstance(best_fit, list) or len(best_fit) != observations.n_observations:
            raise ForMoSAError(f'best_fit must be a list of length {observations.n_observations}', self.logger)
        if plot_native_model and not isinstance(native_model, ObservedModel):
            raise ForMoSAError('If you want to plot the native model, native_model must be an instance of '
                               f'ObservedModel. Got {type(native_model)}', self.logger)

        # Get config for best fit
        config = PLOTS_CONFIG.BestFitPlot
        main_config = MAIN_PLOT

        # Reserve the top rows for the filter axis only when photometry is present
        ax_row_start = 2 if observations.has_photometry else 0

        fig = plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(9, 11)
        # Main axis for observations + best-fit
        ax = fig.add_subplot(gs[ax_row_start:7, 0:10])
        # Axis for photometric filters
        ax_filt = None
        if observations.has_photometry:
            ax_filt = fig.add_subplot(gs[0:2, 0:10], sharex=ax)
        # Residuals and histogram axes
        axr = fig.add_subplot(gs[7:9, 0:10], sharex=ax)
        axr2 = fig.add_subplot(gs[7:9, 10:11], sharey=axr)

        # Plot native model if required
        if plot_native_model:
            ax.plot(native_model.wave, native_model.flux, color=config.color_fit,
                    linewidth=config.linewidth, zorder=config.zorder, label='Best fit native model')

        # Compute every residual once. A single global standard deviation normalises
        # all of them, so residuals of different observations share one scale.
        residuals = [fit.residuals(obs.flux) for fit, obs in zip(best_fit, observations.observations)]
        global_std = np.std(np.concatenate(residuals))
        if global_std == 0:
            # All residuals are zero: avoid 0/0 and plot them unscaled
            global_std = 1.0

        # Plot observations
        observations.plot_all(fig=fig, ax=ax, ax_filt=ax_filt)

        # Plot best-fit and residuals. Only the first best-fit artist gets the
        # 'Best fit' legend label, so the legend does not repeat it per observation.
        # NOTE: the plotted model flux is `total_flux` (not `flux`).
        fit_label = 'Best fit'
        for fit, obs, res in zip(best_fit, observations.observations, residuals):
            res_norm = res / global_std

            if obs.is_photometric:
                if not plot_native_model:
                    # For photometric data, the best fit is shown as scatter points
                    ax.scatter(fit.wave, fit.total_flux, marker='o', c=config.color_fit,
                               zorder=config.zorder, label=fit_label)
                    fit_label = '_nolegend_'
                axr.scatter(obs.wave, res_norm, c=config.color_residuals, marker='o')
            else:
                if not plot_native_model:
                    # For spectroscopic data, the best fit is shown as a line
                    ax.plot(fit.wave, fit.total_flux, color=config.color_fit,
                            linewidth=config.linewidth, zorder=config.zorder, label=fit_label)
                    fit_label = '_nolegend_'
                axr.plot(obs.wave, res_norm, c=config.color_residuals, linewidth=config.linewidth)

            axr2.hist(res_norm, orientation='horizontal', bins=60, color=config.color_residuals,
                      alpha=0.8, density=True)

        axr.set_xlabel(r'Wavelength ($\mu$m)')
        ax.set_xlabel(None)
        axr.set_ylabel(r'Residuals ($\sigma$)')
        axr.axhline(y=0, linestyle='--', color='grey')

        # Minor ticks
        if main_config.minor_ticks:
            axr.xaxis.set_minor_locator(AutoMinorLocator(main_config.nb_minor_ticks))
            axr.yaxis.set_minor_locator(AutoMinorLocator(main_config.nb_minor_ticks))

        # Remove axis for axr2
        axr2.axis('off')

        # Re-render the main legend so the 'Best fit' artist is included
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(handles=handles, labels=labels, frameon=False, loc='upper right',
                      fontsize=main_config.legend_fontsize)

        fig.tight_layout()
        return fig, ax, ax_filt, axr, axr2

    def plot_ccf(self, rv_grid: np.ndarray, ccf: np.ndarray, acf: np.ndarray,
                 ccf_star: np.ndarray | None = None, title: str | None = None) -> tuple[Figure, Axes]:
        '''
        Plot the Cross-Correlation Function (CCF).

        Parameters
        ----------
        rv_grid : np.ndarray
            Grid of radial velocity values (in km/s)
        ccf : np.ndarray
            Corresponding CCF (cross-correlation) values
        acf : np.ndarray
            ACF (auto-correlation) values; plotted shifted to the CCF peak
        ccf_star : np.ndarray | None
            CCF values with star speckles; only plotted if not None and not all zero
        title : str | None
            Optional suffix of the plot title

        Returns
        -------
        tuple[Figure, Axes]
            Figure and Axes objects

        Notes
        -----
        Authors: Bhavesh Rajpoot and Allan Denis
        '''
        self._logger.info(' Plotting CCF')

        # Find best RV (position of the CCF maximum)
        best_idx = np.unravel_index(np.argmax(ccf), ccf.shape)
        rv_peak = rv_grid[best_idx[0]]

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(rv_grid, ccf, label='CCF', color='blue')
        # Centre the ACF on the CCF peak for a direct shape comparison
        ax.plot(rv_grid + rv_peak, acf, label='ACF', color='orange', linestyle='--')
        if ccf_star is not None and np.any(ccf_star != 0):
            ax.plot(rv_grid, ccf_star, label='Star CCF', color='red', alpha=0.5)
        ax.axvline(rv_peak, color='grey', linestyle=':', label=f'RV = {rv_peak:.1f} km/s')
        ax.set_xlabel('RV (km/s)')
        ax.set_ylabel('CCF (SNR)')
        if title is not None:
            ax.set_title(f'CCF - {title}')
        ax.legend()
        return fig, ax

    @staticmethod
    def _pixel_extent(grid: np.ndarray) -> tuple[float, float]:
        '''
        Outer edges of the pixels whose centres are the values of a 1D `grid`.

        The half-spacing at each end is taken from the two nearest grid points; a
        single-point grid gets a pixel of width 1.
        '''
        grid = np.asarray(grid, dtype=float)
        if grid.size < 2:
            return float(grid[0] - 0.5), float(grid[0] + 0.5)
        return (float(grid[0] - 0.5 * (grid[1] - grid[0])),
                float(grid[-1] + 0.5 * (grid[-1] - grid[-2])))

    def plot_rv_vsini_map(self, rv_grid: np.ndarray, vsini_grid: np.ndarray, logL_map: np.ndarray,
                          title: str | None = None) -> tuple[Figure, Axes]:
        '''
        Plot the RV vs v.sin(i) log-likelihood map.

        Parameters
        ----------
        rv_grid : np.ndarray
            Grid of radial velocity values (in km/s), along the columns of `logL_map`
        vsini_grid : np.ndarray
            Grid of v.sin(i) values (in km/s), along the rows of `logL_map`
        logL_map : np.ndarray
            Log-likelihood map of shape (len(vsini_grid), len(rv_grid))
        title : str | None
            Optional suffix of the plot title

        Returns
        -------
        tuple[Figure, Axes]
            Figure and Axes objects

        Notes
        -----
        Authors: Bhavesh Rajpoot (adapted from Allan Denis)
        '''
        self._logger.info(' Computing RV-vsini map')

        # Find best RV and vsini (rows are vsini, columns are RV)
        best_idx = np.unravel_index(np.argmax(logL_map), logL_map.shape)
        best_vsini = vsini_grid[best_idx[0]]
        best_rv = rv_grid[best_idx[1]]

        fig, ax = plt.subplots(figsize=(8, 6))
        # imshow's extent gives the outer edges of the image, so pad the grids by half
        # a pixel: each grid value then lies at the centre of its pixel.
        extent = [*self._pixel_extent(rv_grid), *self._pixel_extent(vsini_grid)]
        im = ax.imshow(logL_map, aspect='auto', origin='lower', extent=extent, cmap='viridis')
        ax.scatter(best_rv, best_vsini, marker='x', color='red', s=100,
                   label=f'Best: RV={best_rv:.1f}, vsini={best_vsini:.1f}')
        ax.set_xlabel('RV (km/s)')
        ax.set_ylabel(r'v.sin(i) (km/s)')
        if title is not None:
            ax.set_title(f'RV - v.sin(i) map - {title}')
        ax.legend()
        fig.colorbar(im, ax=ax, label='log L')
        return fig, ax