Source code for qsospec.fitting.local

"""Public qsospec fitting API."""

from __future__ import annotations

import numpy as np
from typing import Any, Dict, List, Optional, Tuple

from ..config import LineComplexConfig, LocalFitConfig
from ..solvers.least_squares import run_least_squares
from ..parameters import pack_line_complex_parameters
from ..residuals import iron_basis_vector, model_and_residual, model_components, model_vector
from ..result import FitResult, LocalFitResult
from ..spectrum import Spectrum, require_rest_frame_flux
from ..warnings import FitWarning


def _prepare_iron_component(
    config: LineComplexConfig,
    wave_fit: np.ndarray,
) -> Tuple[Optional[np.ndarray], Optional[Any], Optional[str], Optional[Dict[str, Any]], List[FitWarning]]:
    """Load and prepare an optional iron template once for a local fit."""

    if config.iron is None or not config.iron.enabled:
        return None, None, None, None, []

    from ..templates.iron import prepare_iron_template
    from ..templates.registry import load_iron_template

    template = load_iron_template(
        config.iron.template,
        template_path=config.iron.template_path,
        normalization=config.iron.normalization,
    )
    prepared = prepare_iron_template(
        template,
        wave_fit,
        tuple(config.window),
        fwhm_kms=config.iron.fwhm_kms,
    )
    metadata = {
        "enabled": True,
        "template": template.name,
        "template_requested": config.iron.template,
        "template_source_path": template.source_path,
        "template_reference": template.reference,
        "template_coverage_min": float(template.coverage[0]) if template.coverage else float("nan"),
        "template_coverage_max": float(template.coverage[1]) if template.coverage else float("nan"),
        "normalization": template.normalization,
        "fwhm_initial_kms": float(config.iron.fwhm_kms),
        "fwhm_bounds": tuple(config.iron.fwhm_bounds),
        "fwhm_param": "iron.fwhm_kms" if prepared.has_overlap else None,
        "has_overlap": bool(prepared.has_overlap),
        "amp_param": "iron.amp" if prepared.has_overlap else None,
        "notes": list(template.notes),
    }
    basis = prepared.basis if prepared.has_overlap else None
    fit_template = template if prepared.has_overlap else None
    return basis, fit_template, template.name, metadata, list(prepared.warnings)


def _window_name(config: LineComplexConfig, index: int = 0) -> str:
    return config.name or f"line_complex_{index}"


def _flux_scale_warning(spectrum: Spectrum) -> List[FitWarning]:
    if spectrum.flux_density_scale_to_cgs is not None:
        return []
    return [
        FitWarning(
            code="flux_scale_unknown_cgs_not_reported",
            message="Flux-density scale to cgs is unknown; cgs line fluxes are not reported.",
        )
    ]


def _window_selector(wave_rest: np.ndarray, windows: List[Tuple[float, float]]) -> np.ndarray:
    selector = np.zeros_like(wave_rest, dtype=bool)
    for lo, hi in windows:
        selector |= (wave_rest >= float(lo)) & (wave_rest <= float(hi))
    return selector


def _fit_pixel_mask(spectrum: Spectrum, config: LineComplexConfig) -> np.ndarray:
    """Return the boolean mask of pixels that enter the local fit.

    A pixel is used when it satisfies all of the following:
    - it is valid (``spectrum.valid_mask``);
    - it lies inside one of ``config.fit_windows`` (default: ``config.window``);
    - it lies outside every ``config.mask_windows`` interval;
    - it lies inside ``config.window`` itself.
    """
    wave = spectrum.wave_rest
    if config.fit_windows is None:
        include_windows = [tuple(config.window)]
    else:
        include_windows = list(config.fit_windows)

    # ``&`` builds a new array, so the in-place updates below never touch valid_mask.
    mask = spectrum.valid_mask & _window_selector(wave, include_windows)
    if config.mask_windows:
        mask &= ~_window_selector(wave, list(config.mask_windows))

    # Clip to the outer window even when fit_windows extend beyond it.
    window_lo, window_hi = map(float, config.window)
    mask &= (wave >= window_lo) & (wave <= window_hi)
    return mask


def _trapezoid(y: np.ndarray, x: np.ndarray) -> float:
    """Integrate ``y`` over ``x`` with the trapezoidal rule.

    ``np.trapezoid`` was introduced in NumPy 2.0; older releases only provide
    ``np.trapz``, so fall back to that name when the new one is missing.
    """
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return float(integrate(y, x))


def _summarize_iron_fit(
    iron_metadata: Dict[str, Any],
    param_values: Any,
    theta: np.ndarray,
    packed: Any,
    wave_fit: np.ndarray,
    spectrum: Spectrum,
) -> Dict[str, Any]:
    """Add the best-fit iron amplitude, FWHM and integrated flux to ``iron_metadata``.

    ``iron_metadata`` is updated in place and also returned.

    - Values that cannot be determined are reported as NaN.
    - ``iron_flux_input`` integrates ``amp * basis`` over ``wave_fit``.
    - ``iron_flux_cgs`` multiplies ``iron_flux_input`` by
      ``spectrum.metadata.flux_scale`` when that scale is known.
    """
    amp = param_values.get("iron.amp")
    fwhm = param_values.get("iron.fwhm_kms")
    iron_flux_input = float("nan")
    iron_flux_cgs = float("nan")
    final_iron_basis = iron_basis_vector(theta, packed, wave_fit)
    if amp is not None and final_iron_basis is not None:
        iron_model = float(amp) * final_iron_basis
        iron_flux_input = _trapezoid(iron_model, wave_fit)
        scale = spectrum.metadata.flux_scale
        iron_flux_cgs = iron_flux_input * float(scale) if scale is not None else float("nan")
    iron_metadata.update(
        {
            "iron_amp": float(amp) if amp is not None else float("nan"),
            "fwhm_kms": float(fwhm) if fwhm is not None else float("nan"),
            "iron_flux_input": iron_flux_input,
            "iron_flux_cgs": iron_flux_cgs,
        }
    )
    return iron_metadata


def fit_line_complex(
    spectrum: Spectrum,
    config: LineComplexConfig,
    jacobian: Optional[str] = None,
) -> FitResult:
    """Fit one local Gaussian line complex on the spectrum rest-frame grid.

    Parameters
    ----------
    spectrum:
        Spectrum with rest-frame wavelengths. ``require_rest_frame_flux`` is
        checked first.
    config:
        Line-complex configuration. It supplies the window, fit/mask windows,
        components, optional iron template, Jacobian mode and ``max_nfev``.
    jacobian:
        Optional override for ``config.jacobian``; ``None`` uses the config value.

    Returns
    -------
    FitResult
        Model and residual arrays are returned on two grids:

        - the fitted pixels (``*_fit``);
        - every valid pixel inside ``config.window`` (``*_window``). This
          includes pixels excluded by ``fit_windows``/``mask_windows``;
          ``fit_used_window`` marks which of them were fitted.

        ``dof`` is ``n_fit_pixels - n_parameters``, floored at 0.
        ``reduced_chi2`` is NaN when ``dof`` is 0.

    Raises
    ------
    ValueError
        If no valid pixels fall in the fit region, or if there are not more
        fit pixels than fitted parameters.
    """
    require_rest_frame_flux(spectrum)
    wave_rest = spectrum.wave_rest
    lo, hi = map(float, config.window)
    fit_mask = _fit_pixel_mask(spectrum, config)
    window_mask = spectrum.valid_mask & (wave_rest >= lo) & (wave_rest <= hi)
    n_pixels = int(np.count_nonzero(fit_mask))
    if n_pixels == 0:
        raise ValueError("No valid pixels fall inside the requested line-complex window.")

    wave_fit = wave_rest[fit_mask]
    flux_fit = spectrum.flux[fit_mask]
    err_fit = spectrum.err[fit_mask]

    iron_basis, iron_template, iron_template_name, iron_metadata, iron_warnings = _prepare_iron_component(
        config, wave_fit
    )
    packed = pack_line_complex_parameters(
        config,
        wave_fit,
        flux_fit=flux_fit,
        iron_basis=iron_basis,
        iron_template=iron_template,
        iron_template_name=iron_template_name,
    )
    if n_pixels <= packed.initial.size:
        raise ValueError(
            f"Too few valid pixels ({n_pixels}) for {packed.initial.size} fitted parameters."
        )

    jac_mode = config.jacobian if jacobian is None else jacobian
    optimizer_result = run_least_squares(
        packed,
        wave_fit,
        flux_fit,
        err_fit,
        jacobian=jac_mode,
        max_nfev=config.max_nfev,
    )
    theta = optimizer_result.x

    # Fit grid: model and residual as returned by model_and_residual
    # (presumably error-weighted, since chi2 below is sum(residual**2)).
    model, residual = model_and_residual(theta, packed, wave_fit, flux_fit, err_fit)

    # Window grid: every valid pixel in the outer window, including masked-out ones.
    wave_window = wave_rest[window_mask]
    flux_window = spectrum.flux[window_mask]
    err_window = spectrum.err[window_mask]
    model_window = model_vector(theta, packed, wave_window)
    residual_window = (flux_window - model_window) / err_window
    fit_used_window = fit_mask[window_mask]

    chi2 = float(np.sum(residual * residual))
    dof = int(max(wave_fit.size - theta.size, 0))
    reduced_chi2 = float(chi2 / dof) if dof > 0 else float("nan")
    param_values = packed.unpack(theta)

    warnings = _flux_scale_warning(spectrum)
    warnings.extend(iron_warnings)
    if not optimizer_result.success:
        warnings.append(
            FitWarning(
                code="fit_failed",
                message=str(optimizer_result.message),
                context={"window": config.name, "status": int(optimizer_result.status)},
            )
        )

    metadata = spectrum.metadata.to_dict()
    metadata.update(
        {
            "window_name": config.name,
            "window": tuple(config.window),
            "fit_windows": list(config.fit_windows) if config.fit_windows is not None else [tuple(config.window)],
            # _fit_pixel_mask tolerates a None/empty mask_windows; do the same here
            # instead of failing with TypeError from list(None).
            "mask_windows": list(config.mask_windows or ()),
            "plot_window": tuple(config.plot_window) if config.plot_window is not None else tuple(config.window),
            "line_center": float(config.center),
            "jacobian": jac_mode,
        }
    )
    if iron_metadata is not None:
        metadata["iron"] = _summarize_iron_fit(iron_metadata, param_values, theta, packed, wave_fit, spectrum)

    return FitResult(
        success=bool(optimizer_result.success),
        status=int(optimizer_result.status),
        message=str(optimizer_result.message),
        theta=np.asarray(theta, dtype=float),
        param_names=list(packed.names),
        param_values=param_values,
        chi2=chi2,
        dof=dof,
        reduced_chi2=reduced_chi2,
        model=model,
        residual=residual,
        wave_rest_fit=wave_fit,
        flux_fit=flux_fit,
        err_fit=err_fit,
        wave_rest_window=wave_window,
        flux_window=flux_window,
        err_window=err_window,
        model_window=model_window,
        residual_window=residual_window,
        fit_used_window=fit_used_window,
        component_models=model_components(theta, packed, wave_fit),
        component_models_window=model_components(theta, packed, wave_window),
        warnings=warnings,
        metadata=metadata,
        optimizer_result=optimizer_result,
    )
def _n_parameters(config: LineComplexConfig) -> int: n = 3 * len(config.components) if config.local_continuum == "constant": n += 1 elif config.local_continuum == "linear": n += 2 if config.iron is not None and config.iron.enabled: n += 2 return n def _validate_local_window( spectrum: Spectrum, config: LineComplexConfig, local_config: LocalFitConfig, window_name: str, ) -> Tuple[bool, List[FitWarning]]: wave_rest = spectrum.wave_rest valid = spectrum.valid_mask warnings: List[FitWarning] = [] if not np.any(valid): warnings.append( FitWarning( code="all_pixels_invalid", message="No finite pixels with positive errors are available.", severity="error", context={"window": window_name}, ) ) return False, warnings valid_wave = wave_rest[valid] coverage_min = float(np.nanmin(valid_wave)) coverage_max = float(np.nanmax(valid_wave)) lo, hi = map(float, config.window) context = { "window": window_name, "requested_window": (lo, hi), "coverage": (coverage_min, coverage_max), } if hi < coverage_min or lo > coverage_max: warnings.append( FitWarning( code="window_not_covered", message="Requested local window is outside the valid rest-frame spectrum coverage.", severity="error", context=context, ) ) return False, warnings if lo < coverage_min or hi > coverage_max: warnings.append( FitWarning( code="window_not_covered", message="Requested local window is only partially covered by the valid rest-frame spectrum.", context=context, ) ) if config.center < lo or config.center > hi or config.center < coverage_min or config.center > coverage_max: warnings.append( FitWarning( code="line_center_outside_coverage", message="Nominal line center is outside the requested window or valid spectrum coverage.", severity="error", context={**context, "line_center": float(config.center)}, ) ) return False, warnings fit_mask = _fit_pixel_mask(spectrum, config) n_pixels = int(np.count_nonzero(fit_mask)) n_required = max(int(local_config.require_min_pixels), _n_parameters(config) + 1) if n_pixels < 
n_required: warnings.append( FitWarning( code="window_too_few_pixels", message=f"Only {n_pixels} valid pixels are available; at least {n_required} are required.", severity="error", context={**context, "n_pixels": n_pixels, "n_required": n_required}, ) ) return False, warnings if local_config.edge_buffer > 0: covered_lo = max(lo, coverage_min) covered_hi = min(hi, coverage_max) if (config.center - covered_lo) <= local_config.edge_buffer or (covered_hi - config.center) <= local_config.edge_buffer: warnings.append( FitWarning( code="line_center_near_edge", message="Nominal line center is close to the covered local-window edge.", context={**context, "line_center": float(config.center), "edge_buffer": local_config.edge_buffer}, ) ) return True, warnings
def fit_local(spectrum: Spectrum, config: LocalFitConfig) -> LocalFitResult:
    """Fit one or more local line-complex windows independently.

    Each window in ``config.windows`` is first checked by
    ``_validate_local_window``.

    - A window that fails validation gets
      ``FitResult.failed("Local window validation failed.")``.
    - A window whose fit raises gets a failed result whose extra error warning
      uses the exception's ``code`` attribute (default ``"fit_failed"``). One
      bad window therefore does not abort the others.
    - Results are keyed by window name. If two windows resolve to the same
      name, the later one is stored under ``"<name>_<index>"`` and carries a
      ``duplicate_window_name`` warning, so no result is silently overwritten.

    Returns
    -------
    LocalFitResult
        ``success`` is True when at least one window fit succeeded.
    """
    require_rest_frame_flux(spectrum)
    window_results: Dict[str, FitResult] = {}
    all_warnings: List[FitWarning] = []

    for index, window_config in enumerate(config.windows):
        requested_name = _window_name(window_config, index=index)
        name = requested_name
        while name in window_results:
            name = f"{name}_{index}"

        can_fit, warnings = _validate_local_window(spectrum, window_config, config, name)
        if name != requested_name:
            warnings.append(
                FitWarning(
                    code="duplicate_window_name",
                    message=f"Window name {requested_name!r} is already used; results are stored under {name!r}.",
                    context={"window": name, "requested_name": requested_name},
                )
            )

        metadata = spectrum.metadata.to_dict()
        metadata.update({"window_name": name, "mode": config.mode, "window": tuple(window_config.window)})

        if not can_fit:
            all_warnings.extend(warnings)
            window_results[name] = FitResult.failed(
                "Local window validation failed.",
                warnings=warnings,
                metadata=metadata,
            )
            continue

        try:
            result = fit_line_complex(spectrum, window_config)
        except Exception as exc:
            # Deliberate boundary: convert any per-window failure into a failed result.
            code = getattr(exc, "code", "fit_failed")
            warning = FitWarning(
                code=code,
                message=str(exc),
                severity="error",
                context={"window": name},
            )
            all_warnings.extend(warnings)
            all_warnings.append(warning)
            window_results[name] = FitResult.failed(str(exc), warnings=warnings + [warning], metadata=metadata)
            continue

        result.metadata["window_name"] = name
        if warnings:
            result.warnings.extend(warnings)
        window_results[name] = result
        all_warnings.extend(result.warnings)

    return LocalFitResult(
        success=any(result.success for result in window_results.values()),
        window_results=window_results,
        warnings=all_warnings,
        metadata={"mode": config.mode, "n_windows": len(config.windows)},
    )