# Source code for qsospec.workflows.host_workflow

"""Optional pPXF host subtraction before qsospec fitting."""

from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np

from ..fitting.local import fit_local
from ..config import (
    GalacticExtinctionConfig,
    GlobalContinuumConfig,
    HalphaComplexConfig,
    HbetaComplexConfig,
    LyaNVComplexConfig,
    LocalFitConfig,
    MgIIComplexConfig,
    UncertaintyConfig,
)
from ..extinction import correct_spectrum_data
from ..fitting.global_fit import fit_global_lines
from ..complex_recipes import ComplexRecipe
from ..global_result import WorkflowResult
from ..result import LocalFitResult
from ..spectrum import Spectrum
from ..warnings import FitWarning


def _host_decomp_decision(requested: bool, redshift: Optional[float]) -> Tuple[bool, Optional[str]]:
    """Resolve the object-level pPXF redshift gate."""

    if not requested:
        return False, None
    try:
        value = float(redshift)
    except (TypeError, ValueError):
        return False, "missing_redshift"
    if not np.isfinite(value):
        return False, "missing_redshift"
    if value >= 1.2:
        return False, "redshift_at_or_above_1.2"
    return True, None


@dataclass
class HostWorkflowResult:
    """Outcome of an optional pPXF host subtraction followed by a local qsospec fit.

    Holds the spectrum before host subtraction, the spectrum that was
    actually fitted, the local fit result, and any pPXF host products.
    """

    # Spectrum before any host subtraction.
    total_spectrum: Spectrum
    # Spectrum handed to the local fitter (host-subtracted when
    # ``host_decomp_enabled`` is true, otherwise the same data as
    # ``total_spectrum``).
    fit_spectrum: Spectrum
    # Result of the local qsospec fit on ``fit_spectrum``.
    local_result: LocalFitResult
    # Whether pPXF host decomposition actually ran for this object.
    host_decomp_enabled: bool

    # Optional host-decomposition products; all default to None.
    host_fit: Optional[Any] = None
    host_sed: Optional[Any] = None
    # Host model evaluated on the quasar wavelength grid.
    host_model_on_quasar_grid: Optional[np.ndarray] = None
    # Flux after removing the host model.
    host_subtracted_flux: Optional[np.ndarray] = None
    host_warnings: Optional[list] = None
    # Free-form provenance information about the run.
    metadata: Optional[Dict[str, Any]] = None
def _good_mask_from_spectrum_data(spectrum_data: Any, extra_mask: Optional[np.ndarray] = None) -> np.ndarray: wave = np.asarray(spectrum_data.wave_obs, dtype=float) flux = np.asarray(spectrum_data.flux, dtype=float) err = np.asarray(spectrum_data.uncertainty(), dtype=float) good = np.isfinite(wave) & np.isfinite(flux) & np.isfinite(err) & (wave > 0) & (err > 0) if spectrum_data.ivar is not None: ivar = np.asarray(spectrum_data.ivar, dtype=float) good &= np.isfinite(ivar) & (ivar > 0) if spectrum_data.mask is not None: good &= np.asarray(spectrum_data.mask) == 0 if extra_mask is not None: good &= np.asarray(extra_mask, dtype=bool) return good def _spectrum_from_arrays( wave_obs: np.ndarray, flux: np.ndarray, err: np.ndarray, redshift: float, mask: Optional[np.ndarray], source: str, spectrum_data: Optional[Any] = None, ) -> Spectrum: source_metadata = ( spectrum_data.metadata.get("spectrum_metadata") if spectrum_data is not None else None ) extinction = ( dict(spectrum_data.metadata.get("galactic_extinction", {})) if spectrum_data is not None else {} ) if spectrum_data is not None: base_metadata = dict(source_metadata or {}) for key in ( "flux_unit", "flux_scale", "flux_frame", "rest_frame_conversion", ): if key in spectrum_data.metadata: base_metadata[key] = spectrum_data.metadata[key] base_metadata.update( { "source": source, "ra": spectrum_data.ra, "dec": spectrum_data.dec, "galactic_extinction_corrected": extinction.get("status") in ( "applied", "declared_corrected", "caller_preprocessed", ), "galactic_extinction": extinction, } ) else: base_metadata = source_metadata return Spectrum.from_arrays( wave_obs, flux, err=err, z=float(redshift), wave_frame="observed", mask=mask, survey=None if base_metadata is not None else "desi", source=source, ra=None if spectrum_data is None else spectrum_data.ra, dec=None if spectrum_data is None else spectrum_data.dec, galactic_extinction_corrected=extinction.get("status") in ( "applied", "declared_corrected", 
"caller_preprocessed", ), galactic_extinction=extinction, metadata=base_metadata, ) def _spectrum_from_spectrum_data(spectrum_data: Any, source: str) -> Spectrum: good = _good_mask_from_spectrum_data(spectrum_data) return _spectrum_from_arrays( np.asarray(spectrum_data.wave_obs, dtype=float), np.asarray(spectrum_data.flux, dtype=float), np.asarray(spectrum_data.uncertainty(), dtype=float), float(spectrum_data.redshift), good, source=source, spectrum_data=spectrum_data, ) def _host_subtracted_spectrum( spectrum_data: Any, *, redshift: Optional[float], template_root: str, template_file: str, fit_range: Tuple[float, float], host_config: Optional[Any], source: str, ) -> Tuple[Spectrum, Spectrum, Any, Any, np.ndarray, np.ndarray, list]: from .host.config import default_config from .host.ppxf_host import ( prepare_desi_for_host_decomp, predict_host_sed, predict_host_sed_on_grid, run_ppxf_host_fit, ) from .host.templates import load_ppxf_npz_templates cfg = host_config or default_config() templates = load_ppxf_npz_templates(template_root=template_root, template_file=template_file) prep = prepare_desi_for_host_decomp( spectrum_data, redshift=redshift, fit_range=fit_range, line_mask_widths=cfg.line_mask_widths, broad_line_mask_widths=cfg.broad_line_mask_widths, observed_artifact_windows=cfg.observed_artifact_windows, max_native_gap_pixels=cfg.max_native_gap_pixels, systematic_error_floor_fraction=cfg.systematic_error_floor_fraction, ) host_fit = run_ppxf_host_fit( prep, templates, agn_powerlaw_slopes=cfg.agn_powerlaw_slopes, polynomial_degree=cfg.polynomial_degree, multiplicative_polynomial_degree=cfg.multiplicative_polynomial_degree, adaptive_broad_line_max_velocity=cfg.adaptive_broad_line_max_velocity, adaptive_line_residual_sigma=cfg.adaptive_line_residual_sigma, residual_clip_sigma=cfg.residual_clip_sigma, residual_clip_iterations=cfg.residual_clip_iterations, residual_clip_dilation_pixels=cfg.residual_clip_dilation_pixels, max_noise_rescale=cfg.max_noise_rescale, 
minimum_clean_fraction=cfg.minimum_clean_fraction, minimum_clean_pixels=cfg.minimum_clean_pixels, minimum_continuum_snr=cfg.minimum_continuum_snr, maximum_clipped_fraction=cfg.maximum_clipped_fraction, ) host_sed = predict_host_sed(host_fit) host_on_grid, grid_warnings = predict_host_sed_on_grid(host_sed, prep.wave_rest) host_warnings = list(host_fit.warnings) + list(host_sed.warnings) + list(grid_warnings) finite_host = np.isfinite(host_on_grid) host_subtracted_flux = np.asarray(prep.flux, dtype=float) - np.where(finite_host, host_on_grid, 0.0) total_spectrum = _spectrum_from_arrays( prep.wave_obs, prep.flux, prep.error, prep.redshift, np.isfinite(prep.wave_obs) & np.isfinite(prep.flux) & np.isfinite(prep.error) & (prep.error > 0), source=source, spectrum_data=spectrum_data, ) fit_spectrum = _spectrum_from_arrays( prep.wave_obs, host_subtracted_flux, prep.error, prep.redshift, np.isfinite(prep.wave_obs) & np.isfinite(host_subtracted_flux) & np.isfinite(prep.error) & (prep.error > 0) & finite_host, source=f"{source}; host_subtracted=ppxf_sed_grid", spectrum_data=spectrum_data, ) return total_spectrum, fit_spectrum, host_fit, host_sed, host_on_grid, host_subtracted_flux, host_warnings
def fit_with_optional_host_decomp(
    input_path: str,
    local_config: Optional[LocalFitConfig] = None,
    *,
    row_index: Optional[int] = None,
    redshift: Optional[float] = None,
    object_id: Optional[str] = None,
    run_host_decomp: bool = False,
    fit_kind: str = "local",
    template_root: str = "~/tools/ppxf_data",
    template_file: str = "spectra_emiles_9.0.npz",
    host_fit_range: Tuple[float, float] = (3600.0, 7000.0),
    host_config: Optional[Any] = None,
    galactic_extinction_config: Optional[GalacticExtinctionConfig] = None,
    global_config: Optional[GlobalContinuumConfig] = None,
    hbeta_config: Optional[HbetaComplexConfig] = None,
    mgii_config: Optional[MgIIComplexConfig] = None,
    halpha_config: Optional[HalphaComplexConfig] = None,
    lya_nv_config: Optional[LyaNVComplexConfig] = None,
    uncertainty_config: Optional[UncertaintyConfig] = None,
    complexes: Optional[Sequence[Union[str, ComplexRecipe]]] = None,
):
    """Read a spectrum, optionally subtract a pPXF host, then run qsospec.

    ``fit_kind`` may be ``"local"`` or ``"global"``.

    * ``"global"`` forwards every option except ``local_config`` to
      :func:`fit_global_lines_workflow` and returns its result unchanged.
    * ``"local"`` does the following and returns a :class:`HostWorkflowResult`:
      reads the spectrum, applies the Galactic-extinction step, subtracts a
      pPXF host when requested and the redshift gate allows it, and runs
      :func:`fit_local` with ``local_config``. The global, complex and
      uncertainty configurations are not used on this path.

    Raises
    ------
    ValueError
        If ``fit_kind`` is neither ``"local"`` nor ``"global"``, or if
        ``local_config`` is missing for a local fit.
    """
    if fit_kind == "global":
        return fit_global_lines_workflow(
            input_path,
            row_index=row_index,
            redshift=redshift,
            object_id=object_id,
            run_host_decomp=run_host_decomp,
            template_root=template_root,
            template_file=template_file,
            host_fit_range=host_fit_range,
            host_config=host_config,
            galactic_extinction_config=galactic_extinction_config,
            global_config=global_config,
            hbeta_config=hbeta_config,
            mgii_config=mgii_config,
            halpha_config=halpha_config,
            lya_nv_config=lya_nv_config,
            uncertainty_config=uncertainty_config,
            complexes=complexes,
        )
    if fit_kind != "local":
        raise ValueError("fit_kind must be 'local' or 'global'.")
    if local_config is None:
        raise ValueError("local_config is required when fit_kind='local'.")

    # Imported at call time, like the other ``.host`` modules this file uses.
    from .host.io import read_sparcli_spectrum

    raw_data = read_sparcli_spectrum(
        input_path, row_index=row_index, redshift=redshift, object_id=object_id
    )
    spectrum_data = correct_spectrum_data(raw_data, galactic_extinction_config)
    source = f"{input_path}:row_index={row_index}"

    host_decomp_enabled, host_skip_reason = _host_decomp_decision(
        run_host_decomp, spectrum_data.redshift
    )
    if not host_decomp_enabled:
        # No host model: the fitted spectrum is the observed spectrum itself.
        total_spectrum = _spectrum_from_spectrum_data(spectrum_data, source=source)
        fit_spectrum = total_spectrum
        host_fit = host_sed = host_on_grid = host_subtracted_flux = None
        host_warnings = []
    else:
        (
            total_spectrum,
            fit_spectrum,
            host_fit,
            host_sed,
            host_on_grid,
            host_subtracted_flux,
            host_warnings,
        ) = _host_subtracted_spectrum(
            spectrum_data,
            redshift=float(spectrum_data.redshift),
            template_root=template_root,
            template_file=template_file,
            fit_range=host_fit_range,
            host_config=host_config,
            source=source,
        )

    local_result = fit_local(fit_spectrum, local_config)

    return HostWorkflowResult(
        total_spectrum=total_spectrum,
        fit_spectrum=fit_spectrum,
        local_result=local_result,
        host_decomp_enabled=host_decomp_enabled,
        host_fit=host_fit,
        host_sed=host_sed,
        host_model_on_quasar_grid=host_on_grid,
        host_subtracted_flux=host_subtracted_flux,
        host_warnings=host_warnings,
        metadata={
            "input_path": input_path,
            "row_index": row_index,
            "object_id": object_id or spectrum_data.object_id or spectrum_data.targetid,
            "targetid": spectrum_data.targetid,
            "ra": spectrum_data.ra,
            "dec": spectrum_data.dec,
            "redshift": fit_spectrum.z,
            "fit_kind": fit_kind,
            "host_decomp_requested": bool(run_host_decomp),
            "host_decomp_enabled": host_decomp_enabled,
            "host_decomp_skip_reason": host_skip_reason,
            "host_model_source": (
                "template_weighted_sed_on_quasar_grid" if host_decomp_enabled else None
            ),
            "galactic_extinction": dict(
                spectrum_data.metadata.get("galactic_extinction", {})
            ),
        },
    )
def _summarize_mc_results( samples: Dict[str, list], n_requested: int, continuum_success_count: int, complex_success_counts: Dict[str, int], ) -> Dict[str, Any]: percentiles = {} for name, values in samples.items(): finite = np.asarray(values, dtype=float) finite = finite[np.isfinite(finite)] if finite.size: p16, p50, p84 = np.percentile(finite, [16.0, 50.0, 84.0]) percentiles[name] = {"p16": float(p16), "p50": float(p50), "p84": float(p84)} return { "n_requested": int(n_requested), "continuum_success_count": int(continuum_success_count), "complex_success_counts": dict(complex_success_counts), "percentiles": percentiles, } def _run_host_refit_mc( spectrum_data: Any, *, n_trials: int, seed: Optional[int], redshift: Optional[float], template_root: str, template_file: str, host_fit_range: Tuple[float, float], host_config: Optional[Any], source: str, global_config: Optional[GlobalContinuumConfig], hbeta_config: Optional[HbetaComplexConfig], mgii_config: Optional[MgIIComplexConfig], halpha_config: Optional[HalphaComplexConfig], lya_nv_config: Optional[LyaNVComplexConfig] = None, complexes: Optional[Sequence[Union[str, ComplexRecipe]]] = None, ) -> Dict[str, Any]: rng = np.random.default_rng(seed) samples: Dict[str, list] = {} continuum_successes = 0 complex_successes: Dict[str, int] = {} error = np.asarray(spectrum_data.uncertainty(), dtype=float) for _ in range(int(n_trials)): noisy_data = replace( spectrum_data, flux=np.asarray(spectrum_data.flux, dtype=float) + rng.normal(0.0, error), ) try: _, fit_spectrum, _, _, host_on_grid, _, _ = _host_subtracted_spectrum( noisy_data, redshift=redshift, template_root=template_root, template_file=template_file, fit_range=host_fit_range, host_config=host_config, source=source, ) trial = fit_global_lines( fit_spectrum, global_config, hbeta_config, mgii_config, halpha_config, UncertaintyConfig(monte_carlo_trials=0), lya_nv_config=lya_nv_config, host_model_on_grid=host_on_grid, complexes=complexes, ) values = {} if 
trial.continuum_success: continuum_successes += 1 values.update(trial.continuum.param_values) for recipe_id, complex_result in trial.line_complexes.items(): if complex_result.success: complex_successes[recipe_id] = complex_successes.get(recipe_id, 0) + 1 values.update(complex_result.metrics) for name, value in values.items(): if np.isfinite(value): samples.setdefault(name, []).append(float(value)) except Exception: continue return _summarize_mc_results( samples, n_trials, continuum_successes, complex_successes )
def _host_fit_metadata(host_fit: Optional[Any]) -> Dict[str, Any]:
    """Return the ``host_*`` metadata fields derived from a pPXF host fit.

    With ``host_fit=None`` every field takes its "absent" value: ``None``
    for scalars, ``[]`` for the reliability reasons, and ``{}`` for mappings.
    """
    if host_fit is None:
        return {
            "host_ppxf_status": None,
            "host_ppxf_reduced_chi2": None,
            "host_fit_reliable": None,
            "host_fit_reliability_reasons": [],
            "host_fit_quality": {},
            "host_noise_rescale_factors": {},
            "host_mask_components_log": {},
            "host_mask_component_counts": {},
            "host_template_file": None,
            "host_template_wavelength_coverage": None,
        }
    provenance = host_fit.preprocessed.mask_provenance
    return {
        "host_ppxf_status": host_fit.status,
        "host_ppxf_reduced_chi2": float(host_fit.reduced_chi2),
        "host_fit_reliable": bool(host_fit.host_fit_reliable),
        "host_fit_reliability_reasons": list(host_fit.host_fit_reliability_reasons),
        "host_fit_quality": dict(host_fit.quality_metrics),
        "host_noise_rescale_factors": dict(host_fit.noise_rescale_factors),
        # Only log-grid mask components are stored pixel-by-pixel.
        "host_mask_components_log": {
            key: np.asarray(value, dtype=bool).tolist()
            for key, value in provenance.items()
            if str(key).endswith("_log") or str(key) == "log_grid_valid"
        },
        # Every mask component is summarised by its count of set pixels.
        "host_mask_component_counts": {
            key: int(np.count_nonzero(value)) for key, value in provenance.items()
        },
        "host_template_file": host_fit.templates.source_path,
        "host_template_wavelength_coverage": list(host_fit.templates.wavelength_coverage),
    }


def fit_global_lines_workflow(
    input_path: str,
    *,
    row_index: Optional[int] = None,
    redshift: Optional[float] = None,
    object_id: Optional[str] = None,
    run_host_decomp: bool = False,
    template_root: str = "~/tools/ppxf_data",
    template_file: str = "spectra_emiles_9.0.npz",
    host_fit_range: Tuple[float, float] = (3600.0, 7000.0),
    host_config: Optional[Any] = None,
    galactic_extinction_config: Optional[GalacticExtinctionConfig] = None,
    global_config: Optional[GlobalContinuumConfig] = None,
    hbeta_config: Optional[HbetaComplexConfig] = None,
    mgii_config: Optional[MgIIComplexConfig] = None,
    halpha_config: Optional[HalphaComplexConfig] = None,
    lya_nv_config: Optional[LyaNVComplexConfig] = None,
    uncertainty_config: Optional[UncertaintyConfig] = None,
    complexes: Optional[Sequence[Union[str, ComplexRecipe]]] = None,
) -> WorkflowResult:
    """Read one spectrum and run optional pPXF plus global multi-line qsospec.

    The steps are:

    1. Read the spectrum and apply the Galactic-extinction step.
    2. When requested and allowed by the redshift gate, subtract a pPXF host
       model.
    3. Run :func:`fit_global_lines` and annotate the returned
       :class:`WorkflowResult` with host products and provenance metadata.

    When host decomposition is enabled, Monte Carlo trials are requested and
    ``uncertainty_config.refit_host_in_mc`` is set, the primary fit runs
    without its own trials. The Monte Carlo then refits host and lines on
    perturbed spectra, and the summary is stored in ``workflow.monte_carlo``.
    """
    from .host.io import read_sparcli_spectrum

    uncertainty = uncertainty_config or UncertaintyConfig()
    spectrum_data = read_sparcli_spectrum(
        input_path, row_index=row_index, redshift=redshift, object_id=object_id
    )
    spectrum_data = correct_spectrum_data(spectrum_data, galactic_extinction_config)
    source = f"{input_path}:row_index={row_index}"
    host_decomp_enabled, host_skip_reason = _host_decomp_decision(
        run_host_decomp, spectrum_data.redshift
    )

    if host_decomp_enabled:
        total_spectrum, fit_spectrum, host_fit, host_sed, host_on_grid, _, host_warnings = (
            _host_subtracted_spectrum(
                spectrum_data,
                redshift=float(spectrum_data.redshift),
                template_root=template_root,
                template_file=template_file,
                fit_range=host_fit_range,
                host_config=host_config,
                source=source,
            )
        )
    else:
        total_spectrum = _spectrum_from_spectrum_data(spectrum_data, source=source)
        fit_spectrum = total_spectrum
        host_fit = None
        host_sed = None
        host_on_grid = None
        host_warnings = []

    # With host refitting in the Monte Carlo (run at the end of this
    # function), that loop supplies the MC uncertainties, so the primary fit
    # skips its own trials.
    refit_host_in_mc = bool(
        host_decomp_enabled
        and uncertainty.monte_carlo_trials > 0
        and uncertainty.refit_host_in_mc
    )
    primary_uncertainty = (
        replace(uncertainty, monte_carlo_trials=0) if refit_host_in_mc else uncertainty
    )

    workflow = fit_global_lines(
        fit_spectrum,
        global_config,
        hbeta_config,
        mgii_config,
        halpha_config,
        primary_uncertainty,
        lya_nv_config=lya_nv_config,
        host_model_on_grid=host_on_grid,
        complexes=complexes,
    )
    workflow.host_decomp_enabled = host_decomp_enabled
    workflow.total_spectrum = total_spectrum
    workflow.host_fit = host_fit
    workflow.host_sed = host_sed
    workflow.host_model_on_quasar_grid = host_on_grid
    workflow.host_fit_mask = (
        np.asarray(host_fit.preprocessed.fit_mask, dtype=bool).copy()
        if host_fit is not None
        else None
    )
    workflow.host_emission_mask = (
        np.asarray(host_fit.preprocessed.emission_mask, dtype=bool).copy()
        if host_fit is not None
        else None
    )
    workflow.host_warnings = [str(item) for item in host_warnings]
    workflow.metadata.update(
        {
            "input_path": input_path,
            "row_index": row_index,
            "object_id": object_id or spectrum_data.object_id or spectrum_data.targetid,
            "targetid": spectrum_data.targetid,
            "ra": spectrum_data.ra,
            "dec": spectrum_data.dec,
            "redshift": fit_spectrum.z,
            "fit_kind": "global",
            "flux_frame": fit_spectrum.flux_frame,
            "rest_frame_conversion": dict(fit_spectrum.metadata.rest_frame_conversion),
            "host_decomp_requested": bool(run_host_decomp),
            "host_decomp_enabled": host_decomp_enabled,
            "host_decomp_skip_reason": host_skip_reason,
            "host_model_source": (
                "template_weighted_sed_on_quasar_grid" if host_decomp_enabled else None
            ),
            "host_fit_range": list(host_fit_range),
            "host_mask_provenance": "exact" if host_decomp_enabled else "unavailable",
            **_host_fit_metadata(host_fit),
            "galactic_extinction": dict(
                spectrum_data.metadata.get("galactic_extinction", {})
            ),
        }
    )

    if run_host_decomp and not host_decomp_enabled:
        workflow.warnings.append(
            FitWarning(
                code="host_decomp_skipped_redshift",
                message="Host decomposition was requested but skipped by the redshift gate.",
                severity="info",
                context={
                    "redshift": spectrum_data.redshift,
                    "threshold": 1.2,
                    "reason": host_skip_reason,
                },
            )
        )

    if refit_host_in_mc:
        workflow.monte_carlo = _run_host_refit_mc(
            spectrum_data,
            n_trials=uncertainty.monte_carlo_trials,
            seed=uncertainty.random_seed,
            # Use the same resolved redshift as the primary host fit. The
            # ``redshift`` argument is only what was passed to the reader and
            # may be None.
            redshift=float(spectrum_data.redshift),
            template_root=template_root,
            template_file=template_file,
            host_fit_range=host_fit_range,
            host_config=host_config,
            source=source,
            global_config=global_config,
            hbeta_config=hbeta_config,
            mgii_config=mgii_config,
            halpha_config=halpha_config,
            lya_nv_config=lya_nv_config,
            complexes=complexes,
        )
        workflow.metadata["uncertainty_mode"] = "covariance+monte_carlo_host_refit"
    return workflow
def fit_global_hbeta_workflow(
    input_path: str,
    *,
    row_index: Optional[int] = None,
    redshift: Optional[float] = None,
    object_id: Optional[str] = None,
    run_host_decomp: bool = False,
    template_root: str = "~/tools/ppxf_data",
    template_file: str = "spectra_emiles_9.0.npz",
    host_fit_range: Tuple[float, float] = (3600.0, 7000.0),
    host_config: Optional[Any] = None,
    galactic_extinction_config: Optional[GalacticExtinctionConfig] = None,
    global_config: Optional[GlobalContinuumConfig] = None,
    hbeta_config: Optional[HbetaComplexConfig] = None,
    uncertainty_config: Optional[UncertaintyConfig] = None,
) -> WorkflowResult:
    """Compatibility wrapper for :func:`fit_global_lines_workflow`.

    Runs the global workflow restricted to the ``"hbeta_oiii"`` complex and
    marks the result with ``metadata["compatibility_hbeta_mode"] = True``.
    The Mg II, H-alpha and Ly-alpha/N V configurations are not accepted
    here, so the global workflow uses its defaults for them.
    """
    forwarded = {
        "row_index": row_index,
        "redshift": redshift,
        "object_id": object_id,
        "run_host_decomp": run_host_decomp,
        "template_root": template_root,
        "template_file": template_file,
        "host_fit_range": host_fit_range,
        "host_config": host_config,
        "galactic_extinction_config": galactic_extinction_config,
        "global_config": global_config,
        "hbeta_config": hbeta_config,
        "uncertainty_config": uncertainty_config,
    }
    workflow = fit_global_lines_workflow(
        input_path, complexes=("hbeta_oiii",), **forwarded
    )
    workflow.metadata["compatibility_hbeta_mode"] = True
    return workflow