# Source code for qsospec.plotting.local

"""Lightweight plotting helpers for qsospec result objects."""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Optional

import numpy as np

from ..result import FitResult, LocalFitResult


def _setup_matplotlib():
    """Return the ``matplotlib.pyplot`` module, importing it lazily.

    Matplotlib is imported here rather than at module level, so it is only
    needed once one of the plotting functions is actually called.
    """
    from matplotlib import pyplot

    return pyplot


def _percentile_limits(values, percentiles=(1.0, 99.0), pad_fraction=0.08):
    """Return padded ``(low, high)`` axis limits from robust percentiles.

    Every non-``None`` entry of *values* is flattened to a float array and
    pooled, and non-finite samples are discarded. The limits are the
    requested *percentiles* of the pooled data, widened on each side by
    ``pad_fraction`` times their span.

    When the span is zero, the pad falls back to ``pad_fraction * |value|``,
    or to ``1.0`` when that value is zero. Returns ``None`` when no finite
    data remain.
    """
    flattened = []
    for item in values:
        if item is None:
            continue
        flattened.append(np.asarray(item, dtype=float).ravel())
    if len(flattened) == 0:
        return None

    pooled = np.concatenate(flattened)
    pooled = pooled[np.isfinite(pooled)]
    if pooled.size == 0:
        return None

    low, high = np.nanpercentile(pooled, percentiles)
    if not (np.isfinite(low) and np.isfinite(high)):
        return None

    if high != low:
        margin = (high - low) * pad_fraction
    elif low != 0:
        margin = abs(low) * pad_fraction
    else:
        margin = 1.0
    return low - margin, high + margin


def _with_gap_breaks(wave, values, gap_factor=5.0):
    """Insert NaN separators where the wavelength grid has a large gap.

    Matplotlib breaks a line at NaN samples. Inserting a NaN into both
    *wave* and *values* right after each oversized step therefore keeps the
    plotted line from bridging masked or missing wavelength ranges.

    A step counts as a gap when it exceeds ``gap_factor`` times the median of
    the finite, strictly positive steps. Only positive steps can qualify, so
    a descending grid never gets breaks.

    Parameters
    ----------
    wave, values:
        Array-likes of equal length, converted to float arrays.
    gap_factor:
        Multiple of the typical step above which a step is treated as a gap.

    Returns
    -------
    ``(wave, values)`` as float arrays. Inputs shorter than three samples, or
    grids without any positive finite step, are returned unchanged apart
    from the float conversion.
    """
    wave = np.asarray(wave, dtype=float)
    values = np.asarray(values, dtype=float)
    if wave.size < 3:
        return wave, values

    steps = np.diff(wave)
    positive_steps = steps[np.isfinite(steps) & (steps > 0)]
    if positive_steps.size == 0:
        return wave, values

    # The median of finite positive steps is itself positive, so no further
    # validity check on ``typical`` is needed.
    typical = float(np.median(positive_steps))
    breaks = np.flatnonzero(steps > gap_factor * typical)
    if breaks.size == 0:
        return wave, values

    # np.insert with a sequence of indices inserts before each index of the
    # *original* array in a single pass. This replaces the former
    # one-insert-per-gap loop, which re-copied both arrays for every gap.
    insert_at = breaks + 1
    return np.insert(wave, insert_at, np.nan), np.insert(values, insert_at, np.nan)


def _line_plot_data(result):
    """Select the arrays to draw for a successful line fit.

    Returns ``(wave, flux, model, components, xlim)``.

    - Window-level arrays are used when ``result.wave_rest_window`` is
      non-empty; otherwise the fitted-pixel arrays are used.
    - Component models whose shape differs from *wave* are dropped with a
      warning instead of breaking the plot.
    - If ``result.metadata["plot_window"]`` holds two values, *xlim* is that
      ``(low, high)`` pair as floats, and all arrays are cut to it unless no
      sample falls inside. Otherwise *xlim* is ``None``.
    """
    import warnings

    if result.wave_rest_window.size:
        wave = result.wave_rest_window
        flux = result.flux_window
        model = result.model_window
        # Falls back to the fitted-pixel components when no window-level
        # components were stored. Those may live on a different (shorter)
        # grid, which the shape filter below catches.
        raw_components = result.component_models_window or result.component_models
    else:
        wave = result.wave_rest_fit
        flux = result.flux_fit
        model = result.model
        raw_components = result.component_models

    components = {}
    for name, values in (raw_components or {}).items():
        values = np.asarray(values)
        if values.shape != np.shape(wave):
            warnings.warn(
                f"skipping component {name!r}: shape {values.shape} does not "
                f"match the wavelength grid {np.shape(wave)}",
                stacklevel=3,
            )
            continue
        components[name] = values

    xlim = None
    plot_window = result.metadata.get("plot_window")
    if plot_window is not None and len(plot_window) == 2:
        xlim = (float(plot_window[0]), float(plot_window[1]))
        view = (wave >= xlim[0]) & (wave <= xlim[1])
        if np.any(view):
            wave = wave[view]
            flux = flux[view]
            model = model[view]
            components = {name: values[view] for name, values in components.items()}
    return wave, flux, model, components, xlim


def plot_line_result(
    result: FitResult,
    output_path: Optional[str] = None,
    ax=None,
    show_components: bool = True,
    title: Optional[str] = None,
):
    """Plot one fitted local line-complex result.

    Draws the data, the total model and, optionally, every component model
    against rest wavelength. Wavelength gaps are broken with NaN separators,
    and the y-range is set from the 1st–99th percentiles of data and model.

    If the fit failed, or has no fitted pixels, the panel shows
    ``result.message`` (or ``"fit failed"``) with its axis turned off.

    Parameters
    ----------
    result:
        Fit result to draw.
    output_path:
        If given, the figure is saved there at 160 dpi (parent directories
        are created) and the written path is returned.
    ax:
        Existing Matplotlib axes to draw into. When omitted, a new 7x4 inch
        figure is created and is closed after saving.
    show_components:
        Also draw each component model ("continuum" dashed, others dotted).
    title:
        Optional panel title.

    Returns
    -------
    The Matplotlib axes when ``output_path`` is ``None``; otherwise the
    written path as ``str``.
    """
    plt = _setup_matplotlib()
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=(7, 4), constrained_layout=True)
    else:
        fig = ax.figure

    if not result.success or result.wave_rest_fit.size == 0:
        ax.text(
            0.5,
            0.5,
            result.message or "fit failed",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        ax.set_axis_off()
    else:
        wave, flux, model, components, xlim = _line_plot_data(result)
        plot_wave, plot_flux = _with_gap_breaks(wave, flux)
        _, plot_model = _with_gap_breaks(wave, model)
        ax.plot(plot_wave, plot_flux, color="0.25", lw=0.9, label="data")
        ax.plot(plot_wave, plot_model, color="tab:blue", lw=1.2, label="model")
        if show_components:
            for name, component in components.items():
                component_wave, component_values = _with_gap_breaks(wave, component)
                if name == "continuum":
                    ax.plot(
                        component_wave,
                        component_values,
                        color="tab:orange",
                        lw=1.0,
                        ls="--",
                        label="continuum",
                    )
                else:
                    ax.plot(component_wave, component_values, lw=0.9, ls=":", label=name)
        limits = _percentile_limits([flux, model])
        if limits is not None:
            ax.set_ylim(*limits)
        if xlim is not None:
            ax.set_xlim(*xlim)
        # Only successful fits have labelled artists. Calling legend() on the
        # failure panel would merely emit a "no artists with labels" warning.
        ax.legend(loc="best", fontsize=8)

    ax.set_xlabel("Rest wavelength [Angstrom]", fontsize=10)
    ax.set_ylabel(
        f"Flux density [{result.metadata.get('flux_unit', 'relative')}]",
        fontsize=10,
    )
    ax.tick_params(labelsize=9)
    if title:
        ax.set_title(title, fontsize=11)

    if output_path is None:
        return ax
    path = Path(output_path)
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=160)
    finally:
        # Close figures created here even when saving fails, so pyplot does
        # not keep accumulating them. Caller-supplied axes are left alone.
        if owns_figure:
            plt.close(fig)
    return str(path)
def plot_local_result(
    result: LocalFitResult,
    output_path: Optional[str] = None,
    show_components: bool = True,
):
    """Plot all windows from a local fit in stacked panels.

    Each entry of ``result.window_results`` gets one panel, in mapping order,
    drawn with ``plot_line_result`` and titled with the window name. The
    per-panel y labels are cleared in favour of one figure-level label. That
    label uses the ``flux_unit`` metadata of the last window that provides
    one, or ``"input"`` if none does. An empty ``window_results`` yields a
    single placeholder panel instead of a blank one.

    Returns
    -------
    The flattened array of axes when ``output_path`` is ``None``. Otherwise
    the figure is saved at 160 dpi (parent directories are created), closed,
    and the written path is returned as ``str``.
    """
    plt = _setup_matplotlib()
    windows = list(result.window_results.items())
    n_panels = max(len(windows), 1)
    fig, axes = plt.subplots(
        n_panels,
        1,
        figsize=(8, 3.2 * n_panels),
        squeeze=False,
        constrained_layout=True,
    )
    flat_axes = axes.ravel()

    # NOTE(review): if windows carry different flux units, the shared label
    # reflects only the last one that sets "flux_unit".
    unit = "input"
    for ax, (window, fit) in zip(flat_axes, windows):
        plot_line_result(fit, ax=ax, show_components=show_components, title=window)
        unit = fit.metadata.get("flux_unit", unit)
        ax.set_ylabel("")

    if not windows:
        placeholder = flat_axes[0]
        placeholder.text(
            0.5,
            0.5,
            "no local-fit windows",
            ha="center",
            va="center",
            transform=placeholder.transAxes,
        )
        placeholder.set_axis_off()

    fig.supylabel(f"Flux density [{unit}]", fontsize=10)

    if output_path is None:
        return flat_axes
    path = Path(output_path)
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=160)
    finally:
        # Always release the figure created here, even if saving fails.
        plt.close(fig)
    return str(path)
def save_local_window_plots(
    result: LocalFitResult,
    output_dir: str,
    show_components: bool = True,
) -> Dict[str, str]:
    """Write one PNG per local-fit window and return ``{window: path}``.

    Files are named ``<window>_qsospec.png`` inside *output_dir*, which is
    created if needed. Path separators in a window name are replaced by
    ``_``, so every file lands directly in *output_dir* rather than in a
    nested or parent directory. Keys of the returned mapping are the
    original window names; values are the written paths as ``str``.
    """
    import os

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    files = {}
    for window, fit in result.window_results.items():
        stem = str(window)
        # os.altsep is None on POSIX; on Windows both "\\" and "/" separate.
        for separator in (os.sep, os.altsep):
            if separator:
                stem = stem.replace(separator, "_")
        files[window] = plot_line_result(
            fit,
            output_path=str(out / f"{stem}_qsospec.png"),
            show_components=show_components,
            title=window,
        )
    return files