"""Foreground Galactic-extinction queries and dereddening helpers."""
from __future__ import annotations
from dataclasses import asdict, replace
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import astropy.units as u
import numpy as np
from astropy.coordinates import SkyCoord
from dust_extinction.parameter_averages import F99
from .config import GalacticExtinctionConfig
from .metadata import SpectrumMetadata
from .spectrum import Spectrum
from .workflows.host.io import SpectrumData
# File name of the Planck GNILC dust map, looked up under ``<data_dir>/planck``.
_PLANCK_GNILC_FILENAME = "COM_CompMap_Dust-GNILC-Model-Opacity_2048_R2.01.fits"

# Metadata key under which Galactic-extinction provenance is stored.
_PROVENANCE_KEY = "galactic_extinction"

# Metadata key under which rest-frame conversion provenance is stored.
_REST_FRAME_KEY = "rest_frame_conversion"
def _resolved_config(
    config: Optional[GalacticExtinctionConfig],
) -> GalacticExtinctionConfig:
    """Return *config*, or a default ``GalacticExtinctionConfig`` when it is ``None``.

    The test is an explicit ``is None`` check rather than truthiness, so a
    caller-supplied configuration object is never silently replaced by the
    defaults merely because it happens to evaluate as falsy.
    """
    if config is None:
        return GalacticExtinctionConfig()
    return config
def _canonical_map_name(name: str) -> str:
    """Normalise a dust-map name.

    The name is stringified, stripped of surrounding whitespace and
    lower-cased; the alias ``planck16`` is folded into ``planck``. Any other
    value is returned in its normalised form without validation.
    """
    normalized = str(name).strip().lower()
    if normalized in ("planck", "planck16"):
        return "planck"
    return normalized
def _resolved_data_dir(config: GalacticExtinctionConfig) -> Optional[str]:
    """Return the absolute dust-map data directory for *config*, if one is needed.

    An explicitly configured ``dustmaps_data_dir`` always wins. Otherwise the
    ``dustmaps`` default directory is used, but only when a map lookup would
    actually happen (correction enabled and no ``ebv_override``); in every
    other case ``None`` is returned.
    """
    explicit = config.dustmaps_data_dir
    if explicit is not None:
        return str(Path(explicit).expanduser().resolve())

    needs_map = config.enabled and config.ebv_override is None
    if not needs_map:
        return None

    # Imported lazily so ``dustmaps`` is only required when a map is queried.
    from dustmaps.std_paths import data_dir

    default_dir = Path(data_dir())
    return str(default_dir.expanduser().resolve())
def _config_provenance(config: GalacticExtinctionConfig) -> Dict[str, Any]:
    """Return the provenance fields derived purely from *config*.

    The result is ``asdict(config)`` with ``map_name`` canonicalised,
    ``dustmaps_data_dir`` resolved, and a ``map_path`` entry added. The
    ``map_path`` is the Planck GNILC file for the ``planck`` map and the
    ``sfd`` directory for any other map name; it is ``None`` when no map
    lookup would take place (disabled, or ``ebv_override`` set).
    """
    record = asdict(config)
    map_name = _canonical_map_name(config.map_name)
    data_dir = _resolved_data_dir(config)
    record["map_name"] = map_name
    record["dustmaps_data_dir"] = data_dir
    record["map_path"] = None

    uses_map = config.enabled and config.ebv_override is None
    if uses_map and data_dir is not None:
        root = Path(data_dir)
        if map_name == "planck":
            target = root / "planck" / _PLANCK_GNILC_FILENAME
        else:
            target = root / "sfd"
        record["map_path"] = str(target)
    return record
@lru_cache(maxsize=8)
def _dust_query(map_name: str, data_dir: Optional[str]):
    """Build the ``dustmaps`` query object for *map_name* (memoised).

    ``map_name`` must already be canonical (``"planck"`` or ``"sfd"``). When
    ``data_dir`` is ``None`` the query classes are given ``None`` and fall back
    to their own default locations; otherwise the Planck file or the ``sfd``
    sub-directory below ``data_dir`` is used. Results are cached per
    ``(map_name, data_dir)`` pair.

    Raises:
        ValueError: if ``map_name`` is neither ``"planck"`` nor ``"sfd"``.
    """
    if map_name not in ("planck", "sfd"):
        raise ValueError(f"Unsupported Galactic dust map: {map_name!r}")

    root = None if data_dir is None else Path(data_dir)

    if map_name == "planck":
        from dustmaps.planck import PlanckGNILCQuery

        map_fname = (
            None if root is None else str(root / "planck" / _PLANCK_GNILC_FILENAME)
        )
        return PlanckGNILCQuery(map_fname=map_fname)

    from dustmaps.sfd import SFDQuery

    map_dir = None if root is None else str(root / "sfd")
    return SFDQuery(map_dir=map_dir)
def preflight_galactic_extinction(
    config: Optional[GalacticExtinctionConfig] = None,
) -> None:
    """Validate that the configured external dust map can be opened.

    Nothing happens when the correction is disabled or an ``ebv_override`` is
    set, because no map is consulted in those cases. Otherwise the dust-map
    query object is constructed (and cached), so any error raised while
    building it surfaces here instead of during the first query.
    """
    cfg = _resolved_config(config)
    needs_map = cfg.enabled and cfg.ebv_override is None
    if needs_map:
        map_name = _canonical_map_name(cfg.map_name)
        _dust_query(map_name, _resolved_data_dir(cfg))
def query_galactic_ebv(
    ra: Optional[float],
    dec: Optional[float],
    config: Optional[GalacticExtinctionConfig] = None,
) -> Tuple[float, Dict[str, Any]]:
    """Return applied E(B-V) and query provenance for one ICRS coordinate.

    Args:
        ra: Right ascension in degrees, or ``None``.
        dec: Declination in degrees, or ``None``.
        config: Extinction configuration; defaults are used when ``None``.

    Returns:
        ``(applied_ebv, provenance)``. When the correction is disabled the
        value is ``0.0`` and the provenance status is ``"disabled"``. Otherwise
        the value comes from ``config.ebv_override`` or from the dust map, with
        ``sfd_recalibration`` applied for the SFD map only.

    Raises:
        ValueError: if a map lookup is required but the coordinates are
            missing, non-finite or out of range; if the E(B-V) value is
            non-finite; or if it is negative and ``clip_negative_ebv`` is off.
    """
    cfg = _resolved_config(config)
    provenance = _config_provenance(cfg)
    provenance["requested"] = bool(cfg.enabled)
    provenance["ra"] = float(ra) if ra is not None else None
    provenance["dec"] = float(dec) if dec is not None else None

    if not cfg.enabled:
        provenance.update(
            applied=False,
            status="disabled",
            raw_ebv=None,
            applied_ebv=0.0,
            warning=None,
        )
        return 0.0, provenance

    if cfg.ebv_override is None:
        if ra is None or dec is None:
            raise ValueError(
                "Galactic extinction correction requires finite RA and Dec "
                "in degrees, or GalacticExtinctionConfig.ebv_override."
            )
        ra_value, dec_value = float(ra), float(dec)
        coordinates_ok = (
            np.isfinite(ra_value)
            and np.isfinite(dec_value)
            and 0.0 <= ra_value < 360.0
            and -90.0 <= dec_value <= 90.0
        )
        if not coordinates_ok:
            raise ValueError(
                "Galactic extinction correction requires finite ICRS "
                "coordinates with 0 <= RA < 360 and -90 <= Dec <= 90 degrees."
            )
        data_dir = _resolved_data_dir(cfg)
        source = _canonical_map_name(cfg.map_name)
        query = _dust_query(source, data_dir)
        position = SkyCoord(
            ra=ra_value * u.deg,
            dec=dec_value * u.deg,
            frame="icrs",
        )
        # The query may return a scalar or an array; take its first element.
        raw_ebv = float(np.asarray(query(position)).reshape(-1)[0])
    else:
        raw_ebv = float(cfg.ebv_override)
        source = "override"

    if not np.isfinite(raw_ebv):
        raise ValueError("The Galactic dust-map query returned non-finite E(B-V).")

    warning = None
    map_ebv = raw_ebv
    if raw_ebv < 0:
        if not cfg.clip_negative_ebv:
            raise ValueError(
                "The Galactic dust-map query returned negative E(B-V) and "
                "clip_negative_ebv is disabled."
            )
        map_ebv, warning = 0.0, "negative_ebv_clipped_to_zero"

    # Only SFD values are rescaled; Planck and override values are used as-is.
    if source == "sfd":
        applied_ebv = map_ebv * float(cfg.sfd_recalibration)
    else:
        applied_ebv = map_ebv

    provenance.update(
        applied=True,
        status="applied",
        source=source,
        raw_ebv=raw_ebv,
        applied_ebv=applied_ebv,
        warning=warning,
    )
    return float(applied_ebv), provenance
def f99_dereddening_factor(
    wave_obs: np.ndarray,
    ebv: float,
    rv: float = 3.1,
) -> np.ndarray:
    """Return the multiplicative F99 dereddening factor.

    Args:
        wave_obs: One-dimensional observed wavelengths in Angstrom; every
            value must be positive and finite.
        ebv: Colour excess E(B-V) handed to the F99 model; must be finite.
        rv: Total-to-selective extinction ratio R_V; must be finite.

    Returns:
        Array of the same length as ``wave_obs`` holding ``1 / attenuation``,
        i.e. the factor by which observed flux is multiplied to deredden it.

    Raises:
        ValueError: if ``wave_obs`` is not one-dimensional, contains
            non-positive or non-finite values, or falls outside the range
            supported by the F99 model; or if ``ebv`` or ``rv`` is not finite.
    """
    wave = np.asarray(wave_obs, dtype=float)
    if wave.ndim != 1:
        raise ValueError("Observed wavelengths must be a one-dimensional array.")
    if not np.all(np.isfinite(wave)) or np.any(wave <= 0):
        raise ValueError(
            "F99 Galactic extinction correction requires positive, finite "
            "observed wavelengths."
        )

    # A NaN/inf E(B-V) or R_V would otherwise propagate silently into an
    # all-NaN correction factor, so reject it before building the model.
    ebv_value = float(ebv)
    rv_value = float(rv)
    if not (np.isfinite(ebv_value) and np.isfinite(rv_value)):
        raise ValueError(
            "F99 Galactic extinction correction requires finite E(B-V) and R_V."
        )

    model = F99(Rv=rv_value)
    # The model's validity range is expressed in inverse microns.
    inverse_micron = (1.0 / (wave * u.AA)).to_value(1 / u.micron)
    lower, upper = model.x_range[0], model.x_range[1]
    if np.any(inverse_micron < lower) or np.any(inverse_micron > upper):
        # Convert the inverse-micron limits back to Angstrom for the message.
        supported_min = 1.0e4 / upper
        supported_max = 1.0e4 / lower
        raise ValueError(
            "Observed wavelengths fall outside the F99 supported range "
            f"[{supported_min:.1f}, {supported_max:.1f}] Angstrom."
        )

    attenuation = np.asarray(
        model.extinguish(wave * u.AA, Ebv=ebv_value),
        dtype=float,
    )
    return 1.0 / attenuation
def _same_correction(
    existing: Dict[str, Any],
    config: GalacticExtinctionConfig,
    ra: Optional[float],
    dec: Optional[float],
) -> bool:
    """Tell whether *existing* provenance matches what *config* would request.

    Every field derived from ``config`` must be present in ``existing`` with an
    equal value. The coordinates are compared as well, except when
    ``ebv_override`` is set, in which case the position is not checked.
    """
    requested = _config_provenance(config)
    if any(existing.get(key) != value for key, value in requested.items()):
        return False
    if config.ebv_override is not None:
        return True
    return existing.get("ra") == ra and existing.get("dec") == dec
def correct_spectrum_data(
    spectrum: SpectrumData,
    config: Optional[GalacticExtinctionConfig] = None,
) -> SpectrumData:
    """Apply the configured Galactic correction exactly once.

    The provenance stored under the ``galactic_extinction`` metadata key
    decides what happens:

    * status ``caller_preprocessed`` / ``declared_corrected``: the arrays are
      left untouched;
    * status ``disabled``: kept when the request is still disabled and
      otherwise identical, else discarded and the correction is redone;
    * any other non-empty record: accepted only if it matches the requested
      correction, otherwise an error is raised;
    * no record, or an empty one: the correction is applied.

    In every successful case the result is passed through the rest-frame
    normalisation step before being returned.

    Args:
        spectrum: Input spectrum; it is not modified in place.
        config: Extinction configuration; defaults are used when ``None``.

    Returns:
        A ``SpectrumData`` with dereddened ``flux``/``error``/``ivar`` (when the
        correction is enabled) and updated metadata.

    Raises:
        ValueError: if the spectrum already carries a different correction, or
            for any error raised by the E(B-V) query, the F99 factor, or the
            rest-frame conversion.
    """
    cfg = _resolved_config(config)
    metadata = dict(spectrum.metadata)
    existing = metadata.get(_PROVENANCE_KEY)
    # An empty or non-dict record carries no information about a previous
    # correction; treat it as absent rather than as a conflicting correction.
    if not (isinstance(existing, dict) and existing):
        existing = None

    if existing is not None:
        status = existing.get("status")
        if status in ("caller_preprocessed", "declared_corrected"):
            return _prepare_spectrum_data_rest_frame(spectrum)
        if status == "disabled":
            if not cfg.enabled and _same_correction(
                existing, cfg, spectrum.ra, spectrum.dec
            ):
                return _prepare_spectrum_data_rest_frame(spectrum)
            # A "disabled" record left the arrays untouched, so it can be
            # dropped and the (new) request processed from scratch.
            metadata.pop(_PROVENANCE_KEY, None)
            spectrum = replace(spectrum, metadata=metadata)
        elif _same_correction(existing, cfg, spectrum.ra, spectrum.dec):
            return _prepare_spectrum_data_rest_frame(spectrum)
        else:
            raise ValueError(
                "SpectrumData already contains a different Galactic-extinction "
                "correction and the raw arrays are unavailable."
            )

    ebv, provenance = query_galactic_ebv(spectrum.ra, spectrum.dec, cfg)
    if not cfg.enabled:
        metadata[_PROVENANCE_KEY] = provenance
        return _prepare_spectrum_data_rest_frame(
            replace(spectrum, metadata=metadata)
        )

    factor = f99_dereddening_factor(spectrum.wave_obs, ebv, cfg.rv)
    flux = np.asarray(spectrum.flux, dtype=float) * factor
    error = (
        None
        if spectrum.error is None
        else np.asarray(spectrum.error, dtype=float) * factor
    )
    # Inverse variance scales with the inverse square of the flux factor.
    ivar = (
        None
        if spectrum.ivar is None
        else np.asarray(spectrum.ivar, dtype=float) / factor**2
    )
    provenance.update(
        {
            "correction_factor_min": float(np.min(factor)),
            "correction_factor_max": float(np.max(factor)),
        }
    )
    metadata[_PROVENANCE_KEY] = provenance
    return _prepare_spectrum_data_rest_frame(
        replace(
            spectrum,
            flux=flux,
            error=error,
            ivar=ivar,
            metadata=metadata,
        )
    )
def _rest_frame_provenance(
    redshift: float,
    *,
    input_frame: str,
    status: str,
) -> Dict[str, Any]:
    """Build the provenance record for a conversion to rest-frame flux.

    The record stores the redshift, the ``1 + z`` factor applied to flux and
    error, and the matching ``1 / (1 + z)**2`` factor for inverse variance.

    Raises:
        ValueError: if ``1 + z`` is not finite or not strictly positive.
    """
    z = float(redshift)
    one_plus_z = 1.0 + z
    if not (np.isfinite(one_plus_z) and one_plus_z > 0):
        raise ValueError("Rest-frame conversion requires finite z with 1 + z > 0.")
    return dict(
        status=status,
        input_flux_frame=input_frame,
        output_flux_frame="rest",
        redshift=z,
        flux_error_factor=float(one_plus_z),
        inverse_variance_factor=float(1.0 / one_plus_z**2),
    )
def _metadata_with_rest_frame(
    metadata: Dict[str, Any],
    provenance: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a copy of *metadata* marked as rest-frame with *provenance* attached.

    The top-level mapping gets ``flux_frame = "rest"`` and a copy of the
    provenance; when ``spectrum_metadata`` is a nested dict, a copy of it is
    stamped the same way. The input mappings are not mutated.
    """

    def _stamp(mapping: Dict[str, Any]) -> Dict[str, Any]:
        # Shallow copy so the caller's dict is left untouched.
        stamped = dict(mapping)
        stamped["flux_frame"] = "rest"
        stamped[_REST_FRAME_KEY] = dict(provenance)
        return stamped

    output = _stamp(metadata)
    nested = output.get("spectrum_metadata")
    if isinstance(nested, dict):
        output["spectrum_metadata"] = _stamp(nested)
    return output
def _prepare_spectrum_data_rest_frame(
    spectrum: SpectrumData,
) -> SpectrumData:
    """Normalise a ``SpectrumData`` to rest-frame flux, at most once.

    A spectrum already marked ``rest`` is returned unchanged after its stored
    conversion provenance (if any) is checked against its redshift. A spectrum
    marked ``observed`` has ``flux`` and ``error`` multiplied by ``1 + z`` and
    ``ivar`` divided by ``(1 + z)**2``, and its metadata stamped accordingly.

    Raises:
        ValueError: if the redshift is missing or non-finite, the flux frame is
            unknown, the stored provenance conflicts with the redshift, or an
            ``observed`` spectrum already carries conversion provenance.
    """
    metadata = dict(spectrum.metadata)
    frame = str(metadata.get("flux_frame", "observed")).lower()
    redshift = spectrum.redshift
    if redshift is None or not np.isfinite(redshift):
        raise ValueError("Rest-frame conversion requires a finite redshift.")

    recorded = metadata.get(_REST_FRAME_KEY)

    if frame == "rest":
        if isinstance(recorded, dict) and recorded:
            expected = _rest_frame_provenance(
                float(redshift),
                input_frame="observed",
                status="applied",
            )
            consistent = (
                recorded.get("output_flux_frame") == "rest"
                and np.isclose(
                    float(recorded.get("redshift", np.nan)),
                    expected["redshift"],
                )
                and np.isclose(
                    float(recorded.get("flux_error_factor", np.nan)),
                    expected["flux_error_factor"],
                )
            )
            if not consistent:
                raise ValueError(
                    "SpectrumData contains conflicting rest-frame conversion "
                    "provenance."
                )
        return spectrum

    if frame != "observed":
        raise ValueError("SpectrumData flux_frame must be 'observed' or 'rest'.")
    if recorded:
        raise ValueError(
            "SpectrumData is marked observed but already contains rest-frame "
            "conversion provenance."
        )

    provenance = _rest_frame_provenance(
        float(redshift),
        input_frame="observed",
        status="applied",
    )
    scale = provenance["flux_error_factor"]
    rest_flux = np.asarray(spectrum.flux, dtype=float) * scale
    rest_error = (
        None
        if spectrum.error is None
        else np.asarray(spectrum.error, dtype=float) * scale
    )
    rest_ivar = (
        None
        if spectrum.ivar is None
        else np.asarray(spectrum.ivar, dtype=float) / scale**2
    )
    return replace(
        spectrum,
        flux=rest_flux,
        error=rest_error,
        ivar=rest_ivar,
        metadata=_metadata_with_rest_frame(metadata, provenance),
    )
def _updated_spectrum_metadata(
    metadata: SpectrumMetadata,
    *,
    ra: Optional[float],
    dec: Optional[float],
    corrected: bool,
    provenance: Dict[str, Any],
) -> SpectrumMetadata:
    """Return a copy of *metadata* carrying the given extinction state.

    The coordinates, the corrected flag and a copy of the provenance dict are
    written; ``notes`` is copied into a fresh list so the new object does not
    share it with the original.
    """
    changes = {
        "ra": ra,
        "dec": dec,
        "galactic_extinction_corrected": bool(corrected),
        "galactic_extinction": dict(provenance),
        "notes": list(metadata.notes),
    }
    return replace(metadata, **changes)
def _prepare_spectrum_rest_frame(spectrum: Spectrum) -> Spectrum:
    """Normalise an in-memory ``Spectrum`` to rest-frame flux, at most once.

    A spectrum already in the ``rest`` frame is returned unchanged after its
    stored conversion provenance (if any) is checked against ``spectrum.z``.
    An ``observed`` spectrum has ``flux`` and ``err`` multiplied by ``1 + z``
    and its metadata updated with the conversion record.

    Raises:
        ValueError: if the flux frame is unknown, the stored provenance
            conflicts with the redshift, an ``observed`` spectrum already
            carries conversion provenance, or ``1 + z`` is invalid.
    """
    frame = spectrum.flux_frame

    if frame == "rest":
        recorded = spectrum.metadata.rest_frame_conversion
        if recorded:
            expected = _rest_frame_provenance(
                spectrum.z,
                input_frame="observed",
                status="applied",
            )
            consistent = (
                recorded.get("output_flux_frame") == "rest"
                and np.isclose(
                    float(recorded.get("redshift", np.nan)),
                    expected["redshift"],
                )
                and np.isclose(
                    float(recorded.get("flux_error_factor", np.nan)),
                    expected["flux_error_factor"],
                )
            )
            if not consistent:
                raise ValueError(
                    "Spectrum contains conflicting rest-frame conversion "
                    "provenance."
                )
        return spectrum

    if frame != "observed":
        raise ValueError("Spectrum flux_frame must be 'observed' or 'rest'.")
    if spectrum.metadata.rest_frame_conversion:
        raise ValueError(
            "Spectrum is marked observed but already contains rest-frame "
            "conversion provenance."
        )

    provenance = _rest_frame_provenance(
        spectrum.z,
        input_frame="observed",
        status="applied",
    )
    scale = provenance["flux_error_factor"]
    rest_metadata = replace(
        spectrum.metadata,
        flux_frame="rest",
        rest_frame_conversion=dict(provenance),
        notes=list(spectrum.metadata.notes),
    )
    rest_flux = np.asarray(spectrum.flux, dtype=float) * scale
    rest_err = np.asarray(spectrum.err, dtype=float) * scale
    return replace(spectrum, flux=rest_flux, err=rest_err, metadata=rest_metadata)
def prepare_spectrum(
    spectrum: Spectrum,
    *,
    galactic_extinction_config: Optional[GalacticExtinctionConfig] = None,
) -> Spectrum:
    """Prepare an in-memory spectrum for fitting.

    Array spectra are assumed to be uncorrected unless constructed with
    ``galactic_extinction_corrected=True``. Correction is performed in the
    observed frame, then flux and uncertainty are normalized to rest-frame
    F_lambda. Both operations are recorded and applied exactly once.

    Args:
        spectrum: Input spectrum; it is not modified in place.
        galactic_extinction_config: Extinction configuration; defaults are
            used when ``None``.

    Returns:
        The prepared ``Spectrum`` with updated metadata.

    Raises:
        ValueError: if the spectrum already carries a different applied
            correction, if coordinates are required but missing or invalid,
            or for any error raised by the F99 factor or rest-frame step.
    """
    cfg = _resolved_config(galactic_extinction_config)
    source_meta = spectrum.metadata
    ra, dec = source_meta.ra, source_meta.dec
    existing = dict(source_meta.galactic_extinction)
    status = existing.get("status")

    def _with_metadata(corrected: bool, provenance: Dict[str, Any], **arrays: Any):
        # Copy of the spectrum with fresh extinction metadata (and new arrays).
        updated = _updated_spectrum_metadata(
            source_meta,
            ra=ra,
            dec=dec,
            corrected=corrected,
            provenance=provenance,
        )
        return replace(spectrum, metadata=updated, **arrays)

    # A correction is already on record: keep the arrays as they are.
    if status in ("declared_corrected", "caller_preprocessed", "applied"):
        if status == "applied" and not _same_correction(existing, cfg, ra, dec):
            raise ValueError(
                "Spectrum already contains a different Galactic-extinction "
                "correction and the raw arrays are unavailable."
            )
        metadata = _updated_spectrum_metadata(
            source_meta,
            ra=ra,
            dec=dec,
            corrected=True,
            provenance=existing,
        )
        # Reuse the input object when the metadata did not actually change.
        if metadata == source_meta:
            prepared = spectrum
        else:
            prepared = replace(spectrum, metadata=metadata)
        return _prepare_spectrum_rest_frame(prepared)

    # The caller declared the arrays as corrected: record that, apply nothing.
    if source_meta.galactic_extinction_corrected:
        provenance = _config_provenance(cfg)
        provenance.update(
            requested=bool(cfg.enabled),
            applied=False,
            status="declared_corrected",
            ra=ra,
            dec=dec,
            raw_ebv=None,
            applied_ebv=None,
            warning=None,
        )
        return _prepare_spectrum_rest_frame(_with_metadata(True, provenance))

    if not cfg.enabled:
        _, provenance = query_galactic_ebv(ra, dec, cfg)
        return _prepare_spectrum_rest_frame(_with_metadata(False, provenance))

    try:
        ebv, provenance = query_galactic_ebv(ra, dec, cfg)
    except ValueError as exc:
        # Only the "missing coordinates" failure is rephrased for array users.
        if "requires finite RA and Dec" in str(exc):
            raise ValueError(
                "This array spectrum is marked as not corrected for Galactic "
                "extinction. Supply finite ra and dec to Spectrum.from_arrays, "
                "set GalacticExtinctionConfig(ebv_override=...), or set "
                "galactic_extinction_corrected=True when the supplied arrays "
                "have already been dereddened."
            ) from exc
        raise

    factor = f99_dereddening_factor(spectrum.wave_obs, ebv, cfg.rv)
    provenance.update(
        correction_factor_min=float(np.min(factor)),
        correction_factor_max=float(np.max(factor)),
    )
    return _prepare_spectrum_rest_frame(
        _with_metadata(
            True,
            provenance,
            flux=np.asarray(spectrum.flux, dtype=float) * factor,
            err=np.asarray(spectrum.err, dtype=float) * factor,
        )
    )
def correct_spectrum(
    spectrum: Spectrum,
    *,
    ra: Optional[float] = None,
    dec: Optional[float] = None,
    config: Optional[GalacticExtinctionConfig] = None,
) -> Tuple[Spectrum, Dict[str, Any]]:
    """Return a corrected in-memory spectrum and correction provenance.

    The coordinates stored on the spectrum are used unless ``ra`` / ``dec``
    are given. Any existing extinction flag and provenance are cleared before
    the spectrum is handed to ``prepare_spectrum``, so the correction is
    evaluated afresh for ``config``.

    NOTE(review): because the previous record is discarded, calling this on
    arrays that were already dereddened would apply the factor again; confirm
    callers only pass uncorrected spectra.

    Returns:
        ``(corrected_spectrum, provenance)``, where ``provenance`` is a copy of
        the extinction record stored on the returned spectrum.
    """
    original = spectrum.metadata
    resolved_ra = float(ra) if ra is not None else original.ra
    resolved_dec = float(dec) if dec is not None else original.dec
    reset_metadata = replace(
        original,
        ra=resolved_ra,
        dec=resolved_dec,
        galactic_extinction_corrected=False,
        galactic_extinction={},
        notes=list(original.notes),
    )
    uncorrected = replace(spectrum, metadata=reset_metadata)
    corrected = prepare_spectrum(uncorrected, galactic_extinction_config=config)
    return corrected, dict(corrected.metadata.galactic_extinction)