"""Parquet-backed immutable run bundles for qsospec results."""
from __future__ import annotations
from dataclasses import asdict, dataclass, field, is_dataclass
from datetime import datetime, timezone
from contextlib import contextmanager
import hashlib
from importlib.metadata import PackageNotFoundError, version
import json
import os
from pathlib import Path
import shutil
import subprocess
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
from uuid import uuid4
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pads
import pyarrow.parquet as pq
from ..global_result import (
EmissionComplexResult,
GlobalContinuumResult,
WorkflowResult,
)
from ..metadata import resolve_spectrum_metadata
from ..spectrum import Spectrum
from ..warnings import FitWarning
# Schema tag written into every manifest; a manifest carrying a different
# value is rejected when a run is created over it or opened.
SCHEMA_VERSION = "5"
# Names of the tables that make up a run bundle. Each name has a matching
# entry in SCHEMAS and is stored as Parquet shards under ``data/<name>``.
TABLE_NAMES = (
    "inputs",
    "objects",
    "measurements",
    "warnings",
    "models",
    "failures",
    "derived",
)
def _now() -> str:
return datetime.now(timezone.utc).isoformat()
def _jsonable(value: Any) -> Any:
if is_dataclass(value):
return _jsonable(asdict(value))
if isinstance(value, Mapping):
return {
str(key): _jsonable(item)
for key, item in sorted(value.items(), key=lambda pair: str(pair[0]))
}
if isinstance(value, (list, tuple)):
return [_jsonable(item) for item in value]
if isinstance(value, np.ndarray):
return value.tolist()
if isinstance(value, np.generic):
return value.item()
if isinstance(value, Path):
return str(value)
if callable(value):
return getattr(value, "__qualname__", repr(value))
if value is None or isinstance(value, (str, int, float, bool)):
return value
return repr(value)
def configuration_hash(configuration: Mapping[str, Any]) -> str:
    """Return a stable SHA-256 hash for a run configuration.

    The configuration is first normalised with ``_jsonable`` and dumped as
    compact JSON with sorted keys, so the hex digest does not depend on the
    insertion order of the mapping.
    """
    canonical = json.dumps(
        _jsonable(configuration),
        sort_keys=True,
        separators=(",", ":"),
        allow_nan=True,
    )
    hasher = hashlib.sha256()
    hasher.update(canonical.encode("utf-8"))
    return hasher.hexdigest()
def _package_version() -> str:
try:
return version("qsospec")
except PackageNotFoundError:
return "unknown"
def _git_commit(root: Path) -> Optional[str]:
try:
return subprocess.check_output(
["git", "rev-parse", "HEAD"],
cwd=root,
text=True,
stderr=subprocess.DEVNULL,
).strip()
except Exception:
return None
def _key_values(mapping: Mapping[str, Any]) -> list[dict[str, str]]:
    """Encode *mapping* as a key-sorted list of ``{"key", "value"}`` records.

    Keys are stringified and each value is stored as a JSON string, which
    matches the layout of ``KEY_VALUE_TYPE``.
    """
    ordered = sorted(mapping.items(), key=lambda item: str(item[0]))
    records: list[dict[str, str]] = []
    for name, raw in ordered:
        encoded = json.dumps(_jsonable(raw), sort_keys=True, allow_nan=True)
        records.append({"key": str(name), "value": encoded})
    return records
def _from_key_values(items: Optional[Sequence[Mapping[str, str]]]) -> Dict[str, Any]:
output: Dict[str, Any] = {}
for item in items or ():
try:
output[str(item["key"])] = json.loads(item["value"])
except Exception:
output[str(item["key"])] = item.get("value")
return output
# Arrow type for free-form metadata: a list of {key, value} string pairs.
# ``_key_values`` produces records in this shape, with JSON-encoded values.
KEY_VALUE_TYPE = pa.list_(
    pa.struct([pa.field("key", pa.string()), pa.field("value", pa.string())])
)
# Arrow type for the model components stored with each ``models`` row:
# one struct per named component, carrying its sampled values.
COMPONENT_TYPE = pa.list_(
    pa.struct(
        [
            pa.field("section", pa.string()),
            pa.field("recipe_id", pa.string()),
            pa.field("name", pa.string()),
            pa.field("role", pa.string()),
            pa.field("values", pa.list_(pa.float64())),
        ]
    )
)
# Arrow type for the per-complex fit summaries stored with each ``models``
# row: fit status, goodness-of-fit numbers, masks and metadata.
COMPLEX_TYPE = pa.list_(
    pa.struct(
        [
            pa.field("recipe_id", pa.string()),
            pa.field("success", pa.bool_()),
            pa.field("status", pa.int32()),
            pa.field("message", pa.string()),
            pa.field("selected_model", pa.string()),
            pa.field("chi2", pa.float64()),
            pa.field("dof", pa.int64()),
            pa.field("reduced_chi2", pa.float64()),
            pa.field("bic", pa.float64()),
            pa.field("fit_mask", pa.list_(pa.bool_())),
            pa.field("excluded_mask", pa.list_(pa.bool_())),
            pa.field("metadata", KEY_VALUE_TYPE),
        ]
    )
)
# Arrow schema for every table in TABLE_NAMES. Staged rows are built against
# these schemas and promoted shards are cast to them when they differ.
# Every table carries ``run_id``, ``object_key`` and ``object_id``.
SCHEMAS = {
    # One row per processed input record.
    "inputs": pa.schema(
        [
            pa.field("run_id", pa.string()),
            pa.field("object_key", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("source", pa.string()),
            pa.field("row_index", pa.int64()),
            pa.field("reader", pa.string()),
            pa.field("redshift", pa.float64()),
            pa.field("metadata", KEY_VALUE_TYPE),
        ]
    ),
    # One summary row per completed object.
    "objects": pa.schema(
        [
            pa.field("run_id", pa.string()),
            pa.field("object_key", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("redshift", pa.float64()),
            pa.field("ra", pa.float64()),
            pa.field("dec", pa.float64()),
            pa.field("continuum_success", pa.bool_()),
            pa.field("continuum_reduced_chi2", pa.float64()),
            pa.field("host_decomp_enabled", pa.bool_()),
            pa.field("complex_statuses", KEY_VALUE_TYPE),
            pa.field("warning_codes", pa.list_(pa.string())),
            pa.field("metadata", KEY_VALUE_TYPE),
            pa.field("completed_at", pa.string()),
        ]
    ),
    # Long-form numeric results: one row per (section, recipe, quantity).
    "measurements": pa.schema(
        [
            pa.field("run_id", pa.string()),
            pa.field("object_key", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("section", pa.string()),
            pa.field("recipe_id", pa.string()),
            pa.field("feature_id", pa.string()),
            pa.field("role", pa.string()),
            pa.field("quantity", pa.string()),
            pa.field("value", pa.float64()),
            pa.field("error", pa.float64()),
            pa.field("unit", pa.string()),
            pa.field("method", pa.string()),
            pa.field("metadata", KEY_VALUE_TYPE),
        ]
    ),
    # One row per warning raised while processing an object.
    "warnings": pa.schema(
        [
            pa.field("run_id", pa.string()),
            pa.field("object_key", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("section", pa.string()),
            pa.field("recipe_id", pa.string()),
            pa.field("code", pa.string()),
            pa.field("severity", pa.string()),
            pa.field("message", pa.string()),
            pa.field("context", KEY_VALUE_TYPE),
        ]
    ),
    # One row per object holding the arrays needed to rebuild its model.
    "models": pa.schema(
        [
            pa.field("run_id", pa.string()),
            pa.field("object_key", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("redshift", pa.float64()),
            pa.field("wave_rest", pa.list_(pa.float64())),
            pa.field("flux", pa.list_(pa.float64())),
            pa.field("error", pa.list_(pa.float64())),
            pa.field("input_mask", pa.list_(pa.bool_())),
            pa.field("total_flux", pa.list_(pa.float64())),
            pa.field("host_model", pa.list_(pa.float64())),
            pa.field("host_fit_mask", pa.list_(pa.bool_())),
            pa.field("host_emission_mask", pa.list_(pa.bool_())),
            pa.field("continuum_fit_mask", pa.list_(pa.bool_())),
            pa.field("continuum_clip_mask", pa.list_(pa.bool_())),
            pa.field("components", COMPONENT_TYPE),
            pa.field("complexes", COMPLEX_TYPE),
            pa.field("spectrum_metadata", KEY_VALUE_TYPE),
            pa.field("workflow_metadata", KEY_VALUE_TYPE),
        ]
    ),
    # One row per object whose processing raised an exception.
    "failures": pa.schema(
        [
            pa.field("run_id", pa.string()),
            pa.field("object_key", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("source", pa.string()),
            pa.field("row_index", pa.int64()),
            pa.field("exception_type", pa.string()),
            pa.field("message", pa.string()),
            pa.field("traceback", pa.string()),
            pa.field("failed_at", pa.string()),
            pa.field("metadata", KEY_VALUE_TYPE),
        ]
    ),
    # Long-form results of user calculators (see compute_derived_quantities).
    "derived": pa.schema(
        [
            pa.field("run_id", pa.string()),
            pa.field("object_key", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("quantity", pa.string()),
            pa.field("calibration_id", pa.string()),
            pa.field("value", pa.float64()),
            pa.field("statistical_error", pa.float64()),
            pa.field("intrinsic_scatter", pa.float64()),
            pa.field("total_error", pa.float64()),
            pa.field("unit", pa.string()),
            pa.field("metadata", KEY_VALUE_TYPE),
        ]
    ),
}
def _float(value: Any) -> Optional[float]:
try:
output = float(value)
except (TypeError, ValueError):
return None
return output
def _feature_and_role(name: str) -> tuple[Optional[str], Optional[str]]:
lowered = name.lower()
roles = ("very_broad", "broad", "narrow", "wing", "blend")
for role in roles:
token = f"_{role}_"
if token in lowered:
index = lowered.index(token)
return name[:index], role
return None, None
def _measurement_rows(
    result: WorkflowResult,
    run_id: str,
    object_key: str,
    object_id: str,
) -> list[dict[str, Any]]:
    """Flatten the numeric outputs of one workflow into ``measurements`` rows.

    Rows are emitted in this order: continuum parameters, the Balmer
    pseudo-continuum metrics found in the continuum metadata, then for each
    line complex its parameters followed by its metrics, and finally the
    continuum samples from the workflow metadata. Values that cannot be
    converted to ``float`` are skipped.
    """
    collected: list[dict[str, Any]] = []
    identity = {
        "run_id": run_id,
        "object_key": object_key,
        "object_id": object_id,
    }

    def emit(
        section: str,
        recipe_id: Optional[str],
        values: Mapping[str, Any],
        errors: Mapping[str, Any],
        method: str,
    ) -> None:
        # One row per numeric entry of ``values``; errors are looked up by
        # the same key and may be missing.
        for quantity, raw in values.items():
            number = _float(raw)
            if number is None:
                continue
            label = str(quantity)
            feature_id, role = _feature_and_role(label)
            collected.append(
                {
                    **identity,
                    "section": section,
                    "recipe_id": recipe_id,
                    "feature_id": feature_id,
                    "role": role,
                    "quantity": label,
                    "value": number,
                    "error": _float(errors.get(quantity)),
                    "unit": None,
                    "method": method,
                    "metadata": [],
                }
            )

    continuum = result.continuum
    emit(
        "continuum_parameter",
        None,
        continuum.param_values,
        continuum.param_errors,
        "covariance",
    )

    balmer_quantities = (
        "balmer_pseudocontinuum_implied_hbeta_flux_input",
        "balmer_pseudocontinuum_implied_hbeta_flux_cgs",
        "balmer_pseudocontinuum_fwhm_kms",
        "balmer_pseudocontinuum_velocity_kms",
        "balmer_pseudocontinuum_edge_flux_density_input",
    )
    continuum_metadata = continuum.metadata
    emit(
        "continuum_metric",
        None,
        {
            quantity: continuum_metadata[quantity]
            for quantity in balmer_quantities
            if quantity in continuum_metadata
        },
        {},
        "fit_metadata",
    )

    for recipe_id, fit in result.line_complexes.items():
        emit(
            "complex_parameter",
            recipe_id,
            fit.param_values,
            fit.param_errors,
            "covariance",
        )
        emit(
            "complex_metric",
            recipe_id,
            fit.metrics,
            fit.metric_errors,
            "covariance",
        )

    emit(
        "continuum_sample",
        None,
        result.metadata.get("continuum_samples", {}),
        {},
        "interpolation",
    )
    return collected
def _warning_rows(
    result: WorkflowResult,
    run_id: str,
    object_key: str,
    object_id: str,
) -> list[dict[str, Any]]:
    """Collect every warning attached to *result* as ``warnings`` rows.

    Workflow warnings come first, then continuum warnings, then the warnings
    of each line complex, and finally the host warnings, which are stored as
    plain messages under the ``host_warning`` code.
    """
    collected: list[dict[str, Any]] = []

    def record(section, recipe_id, code, severity, message, context) -> None:
        collected.append(
            {
                "run_id": run_id,
                "object_key": object_key,
                "object_id": object_id,
                "section": section,
                "recipe_id": recipe_id,
                "code": code,
                "severity": severity,
                "message": message,
                "context": context,
            }
        )

    def record_all(section, recipe_id, warnings) -> None:
        for warning in warnings:
            record(
                section,
                recipe_id,
                warning.code,
                warning.severity,
                warning.message,
                _key_values(warning.context),
            )

    record_all("workflow", None, result.warnings)
    record_all("continuum", None, result.continuum.warnings)
    for recipe_id, fit in result.line_complexes.items():
        record_all("complex", recipe_id, fit.warnings)
    for message in result.host_warnings:
        record("host", None, "host_warning", "warning", str(message), [])
    return collected
def _component_role(name: str) -> str:
lowered = name.lower()
for role in ("very_broad", "broad", "narrow", "wing", "blend"):
if role in lowered:
return role
return "continuum"
def _model_row(
    result: WorkflowResult,
    run_id: str,
    object_key: str,
    object_id: str,
) -> dict[str, Any]:
    """Build the single ``models`` row that archives *result*.

    The row stores the input spectrum arrays, the optional total spectrum
    and host arrays, the continuum masks, every continuum and complex
    component model, a fit summary per complex, and the spectrum and
    workflow metadata.
    """

    def as_floats(values) -> list:
        return np.asarray(values, dtype=float).tolist()

    def as_bools(values) -> list:
        return np.asarray(values, dtype=bool).tolist()

    spectrum = result.spectrum
    continuum = result.continuum

    components = []
    for name, values in continuum.component_models.items():
        components.append(
            {
                "section": "continuum",
                "recipe_id": None,
                "name": name,
                "role": "continuum",
                "values": as_floats(values),
            }
        )

    complexes = []
    for recipe_id, fit in result.line_complexes.items():
        for name, values in fit.component_models.items():
            components.append(
                {
                    "section": "complex",
                    "recipe_id": recipe_id,
                    "name": name,
                    "role": _component_role(name),
                    "values": as_floats(values),
                }
            )
        if fit.excluded_mask is not None:
            excluded_mask = as_bools(fit.excluded_mask)
        else:
            # No exclusion mask recorded: archive an all-zero mask shaped
            # like the spectrum's valid mask.
            excluded_mask = np.zeros_like(result.spectrum.valid_mask).tolist()
        complexes.append(
            {
                "recipe_id": recipe_id,
                "success": bool(fit.success),
                "status": int(fit.status),
                "message": str(fit.message),
                "selected_model": str(fit.selected_model),
                "chi2": float(fit.chi2),
                "dof": int(fit.dof),
                "reduced_chi2": float(fit.reduced_chi2),
                "bic": float(fit.bic),
                "fit_mask": as_bools(fit.fit_mask),
                "excluded_mask": excluded_mask,
                "metadata": _key_values(fit.metadata),
            }
        )

    input_mask = None if spectrum.mask is None else as_bools(spectrum.mask)
    total_flux = (
        None
        if result.total_spectrum is None
        else as_floats(result.total_spectrum.flux)
    )
    host_model = (
        None
        if result.host_model_on_quasar_grid is None
        else as_floats(result.host_model_on_quasar_grid)
    )
    host_fit_mask = (
        None if result.host_fit_mask is None else as_bools(result.host_fit_mask)
    )
    host_emission_mask = (
        None
        if result.host_emission_mask is None
        else as_bools(result.host_emission_mask)
    )

    return {
        "run_id": run_id,
        "object_key": object_key,
        "object_id": object_id,
        "redshift": float(spectrum.z),
        "wave_rest": as_floats(spectrum.wave_rest),
        "flux": as_floats(spectrum.flux),
        "error": as_floats(spectrum.err),
        "input_mask": input_mask,
        "total_flux": total_flux,
        "host_model": host_model,
        "host_fit_mask": host_fit_mask,
        "host_emission_mask": host_emission_mask,
        "continuum_fit_mask": as_bools(continuum.fit_mask),
        "continuum_clip_mask": as_bools(continuum.clip_mask),
        "components": components,
        "complexes": complexes,
        "spectrum_metadata": _key_values(spectrum.metadata.to_dict()),
        "workflow_metadata": _key_values(result.metadata),
    }
def workflow_payload(
    result: WorkflowResult,
    *,
    run_id: str,
    object_key: str,
    object_id: str,
    input_record: Mapping[str, Any],
) -> Dict[str, list[dict[str, Any]]]:
    """Serialize one workflow into all authoritative run tables.

    The returned mapping has one entry per table name. ``inputs``,
    ``objects`` and ``models`` hold exactly one row each, ``measurements``
    and ``warnings`` hold as many rows as the result provides, and
    ``failures`` and ``derived`` are present but empty.
    """
    workflow_metadata = dict(result.metadata)
    identity = {
        "run_id": run_id,
        "object_key": object_key,
        "object_id": object_id,
    }

    input_row = {
        **identity,
        "source": str(input_record.get("source", "")),
        "row_index": input_record.get("row_index"),
        "reader": str(input_record.get("reader", "auto")),
        "redshift": float(result.spectrum.z),
        "metadata": _key_values(input_record.get("metadata", {})),
    }
    summary_row = {
        **identity,
        "redshift": float(result.spectrum.z),
        "ra": _float(workflow_metadata.get("ra")),
        "dec": _float(workflow_metadata.get("dec")),
        "continuum_success": bool(result.continuum_success),
        "continuum_reduced_chi2": float(result.continuum.reduced_chi2),
        "host_decomp_enabled": bool(result.host_decomp_enabled),
        "complex_statuses": _key_values(result.complex_statuses),
        "warning_codes": list(result.warning_codes()),
        "metadata": _key_values(workflow_metadata),
        "completed_at": _now(),
    }

    payload: Dict[str, list[dict[str, Any]]] = {}
    payload["inputs"] = [input_row]
    payload["objects"] = [summary_row]
    payload["measurements"] = _measurement_rows(
        result, run_id, object_key, object_id
    )
    payload["warnings"] = _warning_rows(result, run_id, object_key, object_id)
    payload["models"] = [_model_row(result, run_id, object_key, object_id)]
    payload["failures"] = []
    payload["derived"] = []
    return payload
def _empty_table(name: str) -> pa.Table:
    """Return a zero-row table carrying the schema registered for *name*."""
    schema = SCHEMAS[name]
    no_rows: list = []
    return pa.Table.from_pylist(no_rows, schema=schema)
@dataclass
class RunStore:
    """Opened immutable run bundle.

    A bundle is a directory holding ``manifest.json``, one directory of
    Parquet shards per table under ``data/<table>``, a ``qa`` directory and
    a ``.staging`` area where shards are written before being promoted.
    Shards are named after a hash of the object key they belong to.
    """

    # Root directory of the bundle.
    path: Path
    # In-memory copy of ``manifest.json``.
    manifest: Dict[str, Any] = field(default_factory=dict)

    @property
    def run_id(self) -> str:
        """Run identifier recorded in the manifest."""
        return str(self.manifest["run_id"])

    @property
    def configuration_hash(self) -> str:
        """Configuration hash recorded in the manifest."""
        return str(self.manifest["configuration_hash"])

    def plot_qa(self, identifier: str, plot_config=None):
        """Load one archived object and return its open QA figure."""
        return load_model(self, identifier).plot_qa(plot_config)

    def show_qa(self, identifier: str, plot_config=None):
        """Display and return one archived object's QA figure."""
        return load_model(self, identifier).show_qa(plot_config)

    @classmethod
    def create(
        cls,
        path: Union[str, Path],
        *,
        configuration: Mapping[str, Any],
        configuration_summary: Optional[Mapping[str, Any]] = None,
        run_id: Optional[str] = None,
        resume: bool = True,
    ) -> "RunStore":
        """Create a run bundle at *path*, or reopen a matching existing one.

        Args:
            path: Bundle directory; created (with parents) when missing.
            configuration: Run configuration; its hash identifies the run.
            configuration_summary: Optional substitute stored in the
                manifest in place of the full configuration.
            run_id: Explicit run identifier; defaults to a UTC timestamp
                followed by the first ten characters of the hash.
            resume: Whether an existing bundle may be reopened.

        Raises:
            ValueError: If an existing manifest has an unsupported schema or
                a different configuration hash.
            FileExistsError: If the bundle exists and *resume* is false.
        """
        root = Path(path).expanduser()
        root.mkdir(parents=True, exist_ok=True)
        manifest_path = root / "manifest.json"
        config_hash = configuration_hash(configuration)
        if manifest_path.exists():
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
            cls._require_current_schema(manifest)
            if manifest.get("configuration_hash") != config_hash:
                raise ValueError(
                    "Run configuration does not match the immutable manifest. "
                    "Choose a new run directory or run_id."
                )
            if not resume:
                raise FileExistsError(f"Run already exists: {root}")
            store = cls(root, manifest)
            store._ensure_directories()
            return store
        actual_run_id = run_id or (
            datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
            + "-"
            + config_hash[:10]
        )
        manifest = {
            "schema_version": SCHEMA_VERSION,
            "run_id": actual_run_id,
            "configuration_hash": config_hash,
            "configuration": _jsonable(
                configuration_summary
                if configuration_summary is not None
                else configuration
            ),
            "cosmology": _jsonable(configuration.get("cosmology")),
            "units": {
                "wavelength": "vacuum Angstrom",
                "velocity": "km/s",
                "model_arrays": (
                    "rest-frame F_lambda in input flux-density scaling"
                ),
            },
            "created_at": _now(),
            "updated_at": _now(),
            "package_version": _package_version(),
            "git_commit": _git_commit(Path(__file__).resolve().parents[3]),
            "status": "active",
            "tables": list(TABLE_NAMES),
            "completed_objects": 0,
            "failed_objects": 0,
        }
        store = cls(root, manifest)
        store._ensure_directories()
        store._write_manifest()
        return store

    @classmethod
    def open(cls, path: Union[str, Path]) -> "RunStore":
        """Open the bundle at *path*; raise ``ValueError`` on a schema mismatch."""
        root = Path(path).expanduser()
        manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
        cls._require_current_schema(manifest)
        return cls(root, manifest)

    @staticmethod
    def _require_current_schema(manifest: Mapping[str, Any]) -> None:
        """Raise ``ValueError`` unless *manifest* declares ``SCHEMA_VERSION``."""
        found = str(manifest.get("schema_version", "missing"))
        if found != SCHEMA_VERSION:
            raise ValueError(
                "Unsupported qsospec run schema "
                f"{found!r}; this version requires schema {SCHEMA_VERSION}. "
                "Recreate the run with the current package."
            )

    @staticmethod
    def _shard_name(object_key: str) -> str:
        """Return the shard file name used for *object_key* in every table."""
        digest = hashlib.sha256(object_key.encode("utf-8")).hexdigest()[:20]
        return f"part-{digest}.parquet"

    def _ensure_directories(self) -> None:
        """Create the ``qa`` and staging directories when missing."""
        (self.path / "qa").mkdir(exist_ok=True)
        self._staging_path().mkdir(exist_ok=True)

    def _table_path(self, name: str) -> Path:
        """Return the shard directory of table *name*."""
        return self.path / "data" / name

    def _staging_path(self) -> Path:
        """Return the directory that holds staging sub-directories."""
        return self.path / ".staging"

    def _write_manifest(self) -> None:
        """Merge, refresh and atomically rewrite ``manifest.json``.

        The on-disk manifest is loaded first and updated with the in-memory
        values, then the timestamp, object counts and per-table shard counts
        are refreshed before the file is replaced through a temporary file.
        """
        with self._manifest_lock():
            manifest_path = self.path / "manifest.json"
            if manifest_path.exists():
                existing = json.loads(manifest_path.read_text(encoding="utf-8"))
                existing.update(self.manifest)
                self.manifest = existing
            self.manifest["updated_at"] = _now()
            # ``read_table`` copes with missing shard directories, so the
            # counts are refreshed unconditionally. This keeps
            # ``failed_objects`` correct even before any object succeeded.
            self.manifest["completed_objects"] = len(self.completed_keys())
            self.manifest["failed_objects"] = len(self.failed_keys())
            self.manifest["shard_state"] = {
                name: len(tuple(self._table_path(name).glob("*.parquet")))
                for name in TABLE_NAMES
            }
            temporary = self.path / f"manifest.{uuid4().hex}.tmp"
            temporary.write_text(
                json.dumps(self.manifest, indent=2, sort_keys=True),
                encoding="utf-8",
            )
            os.replace(temporary, manifest_path)

    @contextmanager
    def _manifest_lock(self):
        """Hold an exclusive ``flock`` on ``.manifest.lock`` while active.

        Where ``fcntl`` cannot be imported the context manager only opens
        and closes the lock file; no locking takes place.
        """
        try:
            import fcntl
        except ImportError:
            fcntl = None
        with (self.path / ".manifest.lock").open("a+") as handle:
            if fcntl is not None:
                fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
            try:
                yield
            finally:
                if fcntl is not None:
                    fcntl.flock(handle.fileno(), fcntl.LOCK_UN)

    def completed_keys(self) -> set[str]:
        """Return the object keys present in the ``objects`` table."""
        table = self.read_table("objects", columns=["object_key"])
        return set(table.column("object_key").to_pylist())

    def failed_keys(self) -> set[str]:
        """Return the object keys present in the ``failures`` table."""
        table = self.read_table("failures", columns=["object_key"])
        return set(table.column("object_key").to_pylist())

    def clear_failure(self, object_key: str) -> None:
        """Delete the failure shard of *object_key*, if there is one."""
        shard = self._table_path("failures") / self._shard_name(object_key)
        shard.unlink(missing_ok=True)

    def stage_payload(
        self,
        payload: Mapping[str, Sequence[Mapping[str, Any]]],
        *,
        namespace: Optional[str] = None,
    ) -> Path:
        """Write one private, collision-free staging directory.

        Each non-empty table in *payload* is written as ``<table>.parquet``
        next to a ``staging.json`` that lists the object keys, the tables
        named by the payload (including empty ones) and a SHA-256 checksum
        per written file.

        Raises:
            FileExistsError: If the staging namespace already exists.
        """
        namespace = namespace or f"{os.getpid()}-{uuid4().hex}"
        staging = self._staging_path() / namespace
        staging.mkdir(parents=True, exist_ok=False)
        object_keys = {
            str(row["object_key"])
            for rows in payload.values()
            for row in rows
            if row.get("object_key") is not None
        }
        checksums = {}
        for name in TABLE_NAMES:
            rows = list(payload.get(name, ()))
            if not rows:
                continue
            table = pa.Table.from_pylist(rows, schema=SCHEMAS[name])
            output = staging / f"{name}.parquet"
            pq.write_table(table, output, compression="zstd")
            checksums[output.name] = hashlib.sha256(output.read_bytes()).hexdigest()
        (staging / "staging.json").write_text(
            json.dumps(
                {
                    "object_keys": sorted(object_keys),
                    "replace_tables": sorted(
                        name for name in payload if name in TABLE_NAMES
                    ),
                    "checksums": checksums,
                }
            ),
            encoding="utf-8",
        )
        return staging

    def promote(self, staging: Union[str, Path]) -> Dict[str, str]:
        """Validate a staging directory and move its shards into the bundle.

        Checksums are verified first. When the staging directory describes
        exactly one object, that object's existing shards are removed from
        every table the payload named, so tables staged empty end up with no
        shard for it. Each staged file is then moved into place with
        ``os.replace``, the staging directory is deleted and the manifest is
        rewritten.

        Returns:
            Mapping of table name to the path of the promoted shard.

        Raises:
            ValueError: On a checksum mismatch or an unknown staged table.
        """
        source = Path(staging)
        promoted: Dict[str, str] = {}
        staging_metadata = json.loads(
            (source / "staging.json").read_text(encoding="utf-8")
        )
        object_keys = staging_metadata.get("object_keys", [])
        replace_tables = set(staging_metadata.get("replace_tables", ()))
        checksums = staging_metadata.get("checksums", {})
        for filename, expected in checksums.items():
            actual = hashlib.sha256((source / filename).read_bytes()).hexdigest()
            if actual != expected:
                raise ValueError(f"Staged shard checksum mismatch: {filename}")
        if len(object_keys) == 1:
            stale_shard = self._shard_name(object_keys[0])
            for name in replace_tables:
                (self._table_path(name) / stale_shard).unlink(missing_ok=True)
        for file_path in sorted(source.glob("*.parquet")):
            name = file_path.stem
            if name not in SCHEMAS:
                raise ValueError(f"Unknown staged table: {name}")
            table = pq.read_table(file_path)
            if not table.schema.equals(SCHEMAS[name], check_metadata=False):
                # Normalise to the canonical schema before promotion.
                table = table.cast(SCHEMAS[name])
                pq.write_table(table, file_path, compression="zstd")
            # The shard is named after the first row's object key; an empty
            # table gets a random name instead.
            object_key = (
                table.column("object_key")[0].as_py()
                if table.num_rows else uuid4().hex
            )
            destination = self._table_path(name) / self._shard_name(object_key)
            destination.parent.mkdir(parents=True, exist_ok=True)
            os.replace(file_path, destination)
            promoted[name] = str(destination)
        shutil.rmtree(source, ignore_errors=True)
        self._write_manifest()
        return promoted

    def write_payload(
        self,
        payload: Mapping[str, Sequence[Mapping[str, Any]]],
    ) -> Dict[str, str]:
        """Stage *payload* and promote it immediately."""
        return self.promote(self.stage_payload(payload))

    def read_table(
        self,
        name: str,
        *,
        columns: Optional[Sequence[str]] = None,
        filter_expression: Any = None,
    ) -> pa.Table:
        """Read table *name* from all of its shards.

        Args:
            name: One of the table names in ``SCHEMAS``.
            columns: Optional column projection.
            filter_expression: Optional dataset filter; it is not applied
                when the table has no shards, since the result is empty
                anyway.

        Raises:
            ValueError: If *name* is not a known table.
        """
        if name not in SCHEMAS:
            raise ValueError(f"Unknown run table: {name!r}")
        files = sorted(self._table_path(name).glob("*.parquet"))
        if not files:
            table = _empty_table(name)
            return table.select(columns) if columns else table
        dataset = pads.dataset([str(path) for path in files], format="parquet")
        return dataset.to_table(columns=columns, filter=filter_expression)

    def object_row(
        self,
        identifier: str,
        *,
        table_name: str = "models",
    ) -> Mapping[str, Any]:
        """Return the single row of *table_name* matching *identifier*.

        The identifier is compared against both ``object_key`` and
        ``object_id``.

        Raises:
            KeyError: If no row matches.
            ValueError: If more than one row matches.
        """
        # Push the match into the scan so that only candidate rows are
        # converted to Python objects (model rows hold large arrays).
        candidate_filter = (pads.field("object_key") == identifier) | (
            pads.field("object_id") == identifier
        )
        table = self.read_table(table_name, filter_expression=candidate_filter)
        matches = [
            row
            for row in table.to_pylist()
            if row["object_key"] == identifier or row["object_id"] == identifier
        ]
        if not matches:
            raise KeyError(f"Object not found in {table_name}: {identifier!r}")
        if len(matches) > 1:
            raise ValueError(
                f"Object identifier is ambiguous; use object_key: {identifier!r}"
            )
        return matches[0]
[docs]
def open_run(path: str) -> RunStore:
    """Open an existing run bundle.

    Thin module-level wrapper around ``RunStore.open``.
    """
    store = RunStore.open(path)
    return store
def _measurement_maps(
    store: RunStore,
    object_key: str,
    section: str,
    recipe_id: Optional[str],
) -> tuple[Dict[str, float], Dict[str, float]]:
    """Return ``(values, errors)`` keyed by quantity for one object section.

    Args:
        store: Run to read the ``measurements`` table from.
        object_key: Object whose measurements are wanted.
        section: Measurement section to keep.
        recipe_id: Recipe to keep; ``None`` selects rows without a recipe.
    """
    # Filter by object inside the scan so the whole table is not converted
    # to Python objects on every call.
    table = store.read_table(
        "measurements",
        filter_expression=pads.field("object_key") == object_key,
    )
    values: Dict[str, float] = {}
    errors: Dict[str, float] = {}
    for row in table.to_pylist():
        # The full predicate is kept in Python: it handles
        # ``recipe_id is None`` and makes the filter a pure optimisation.
        if (
            row["object_key"] == object_key
            and row["section"] == section
            and row["recipe_id"] == recipe_id
        ):
            values[row["quantity"]] = row["value"]
            errors[row["quantity"]] = row["error"]
    return values, errors
def _archived_host_masks(
row: Mapping[str, Any],
wave_rest: np.ndarray,
) -> tuple[Optional[np.ndarray], Optional[np.ndarray], str]:
"""Return exact schema-v5 pPXF masks for one object."""
fit_values = row.get("host_fit_mask")
emission_values = row.get("host_emission_mask")
if fit_values is not None and emission_values is not None:
fit_mask = np.asarray(fit_values, dtype=bool)
emission_mask = np.asarray(emission_values, dtype=bool)
if fit_mask.shape == wave_rest.shape and emission_mask.shape == wave_rest.shape:
return fit_mask, emission_mask, "exact"
return None, None, "unavailable"
[docs]
def load_model(
    run: Union[str, os.PathLike, RunStore],
    identifier: str,
) -> WorkflowResult:
    """Reconstruct a workflow result from the Parquet model archive.

    Args:
        run: An opened ``RunStore`` or the path of a run bundle.
        identifier: ``object_key`` or ``object_id`` of the object to load.

    Returns:
        A ``WorkflowResult`` rebuilt from the ``models``, ``objects``,
        ``measurements`` and ``warnings`` tables. Covariances are not
        archived and are returned as ``None``; the continuum ``chi2`` and
        ``dof`` are placeholders (``nan`` and ``0``).

    Raises:
        KeyError: If the object is not present in the archive.
        ValueError: If *identifier* matches more than one object.
    """
    # Accept ``pathlib.Path`` (any ``os.PathLike``) as well as ``str``;
    # previously a ``Path`` was mistaken for an opened store.
    store = open_run(run) if isinstance(run, (str, os.PathLike)) else run
    row = store.object_row(identifier, table_name="models")
    object_row = store.object_row(row["object_key"], table_name="objects")
    workflow_metadata = _from_key_values(row["workflow_metadata"])
    spectrum_metadata_values = _from_key_values(row["spectrum_metadata"])

    # Fill spectrum metadata gaps from the workflow metadata.
    extinction = dict(
        spectrum_metadata_values.get("galactic_extinction")
        or workflow_metadata.get("galactic_extinction")
        or {}
    )
    extinction_status = extinction.get("status")
    spectrum_metadata_values.setdefault(
        "galactic_extinction_corrected",
        extinction_status in (
            "applied",
            "declared_corrected",
            "caller_preprocessed",
        ),
    )
    spectrum_metadata_values.setdefault("galactic_extinction", extinction)
    spectrum_metadata_values.setdefault("ra", workflow_metadata.get("ra"))
    spectrum_metadata_values.setdefault("dec", workflow_metadata.get("dec"))
    spectrum_metadata = resolve_spectrum_metadata(
        metadata=spectrum_metadata_values
    )

    spectrum = Spectrum.from_arrays(
        np.asarray(row["wave_rest"], dtype=float),
        np.asarray(row["flux"], dtype=float),
        err=np.asarray(row["error"], dtype=float),
        z=float(row["redshift"]),
        wave_frame="rest",
        mask=(
            np.asarray(row["input_mask"], dtype=bool)
            if row["input_mask"] is not None else None
        ),
        metadata=spectrum_metadata,
    )
    total_spectrum = None
    if row["total_flux"] is not None:
        # Same grid, errors and mask as the input spectrum; only the flux
        # differs.
        total_spectrum = Spectrum.from_arrays(
            np.asarray(row["wave_rest"], dtype=float),
            np.asarray(row["total_flux"], dtype=float),
            err=np.asarray(row["error"], dtype=float),
            z=float(row["redshift"]),
            wave_frame="rest",
            mask=(
                np.asarray(row["input_mask"], dtype=bool)
                if row["input_mask"] is not None else None
            ),
            metadata=spectrum_metadata,
        )

    continuum_components = {
        item["name"]: np.asarray(item["values"], dtype=float)
        for item in row["components"]
        if item["section"] == "continuum"
    }
    # The total continuum model is the sum of its archived components.
    continuum_model = sum(
        continuum_components.values(),
        np.zeros_like(spectrum.flux, dtype=float),
    )
    continuum_values, continuum_errors = _measurement_maps(
        store, row["object_key"], "continuum_parameter", None
    )

    # Filter by object inside the scan; the Python predicate is kept so the
    # filter is purely an optimisation.
    warning_rows = [
        item
        for item in store.read_table(
            "warnings",
            filter_expression=pads.field("object_key") == row["object_key"],
        ).to_pylist()
        if item["object_key"] == row["object_key"]
    ]

    def archived_warnings(section: str, recipe_id=None) -> list[FitWarning]:
        """Rebuild the warnings archived for one section and recipe."""
        return [
            FitWarning(
                code=item["code"],
                message=item["message"],
                severity=item["severity"],
                context=_from_key_values(item["context"]),
            )
            for item in warning_rows
            if item["section"] == section
            and item.get("recipe_id") == recipe_id
        ]

    host_enabled = bool(object_row["host_decomp_enabled"])
    host_fit_mask, host_emission_mask, host_mask_provenance = (
        _archived_host_masks(row, spectrum.wave_rest)
    )
    workflow_metadata["host_mask_provenance"] = host_mask_provenance

    continuum = GlobalContinuumResult(
        success=bool(object_row["continuum_success"]),
        status=1 if object_row["continuum_success"] else -1,
        message="Loaded from Parquet model archive.",
        param_values=continuum_values,
        param_errors=continuum_errors,
        covariance=None,
        chi2=np.nan,
        dof=0,
        reduced_chi2=float(object_row["continuum_reduced_chi2"]),
        wave_rest=spectrum.wave_rest.copy(),
        model=continuum_model,
        component_models=continuum_components,
        fit_mask=np.asarray(row["continuum_fit_mask"], dtype=bool),
        clip_mask=np.asarray(row["continuum_clip_mask"], dtype=bool),
        warnings=archived_warnings("continuum"),
        metadata=workflow_metadata,
    )

    # Group the archived complex components by recipe.
    complex_components: Dict[str, Dict[str, np.ndarray]] = {}
    for component in row["components"]:
        if component["section"] == "complex":
            complex_components.setdefault(component["recipe_id"], {})[
                component["name"]
            ] = np.asarray(component["values"], dtype=float)

    complexes: Dict[str, EmissionComplexResult] = {}
    for item in row["complexes"]:
        recipe_id = item["recipe_id"]
        parameters, parameter_errors = _measurement_maps(
            store, row["object_key"], "complex_parameter", recipe_id
        )
        metrics, metric_errors = _measurement_maps(
            store, row["object_key"], "complex_metric", recipe_id
        )
        complex_metadata = _from_key_values(item["metadata"])
        component_models = complex_components.get(recipe_id, {})
        complex_model = sum(
            component_models.values(),
            np.zeros_like(spectrum.flux, dtype=float),
        )
        complexes[recipe_id] = EmissionComplexResult(
            success=bool(item["success"]),
            status=int(item["status"]),
            message=str(item["message"]),
            selected_model=str(item["selected_model"]),
            param_values=parameters,
            param_errors=parameter_errors,
            covariance=None,
            metrics=metrics,
            metric_errors=metric_errors,
            chi2=float(item["chi2"]),
            dof=int(item["dof"]),
            reduced_chi2=float(item["reduced_chi2"]),
            bic=float(item["bic"]),
            wave_rest=spectrum.wave_rest.copy(),
            flux_continuum_subtracted=spectrum.flux - continuum_model,
            err=spectrum.err.copy(),
            model=complex_model,
            component_models=component_models,
            fit_mask=np.asarray(item["fit_mask"], dtype=bool),
            warnings=archived_warnings("complex", recipe_id),
            metadata={
                "recipe_id": recipe_id,
                **complex_metadata,
            },
            excluded_mask=np.asarray(item["excluded_mask"], dtype=bool),
        )

    warnings = archived_warnings("workflow")
    host_warnings = [
        item["message"]
        for item in warning_rows
        if item["section"] == "host"
    ]
    # Only the final fits are archived, so the "initial" slots reuse them.
    workflow = WorkflowResult(
        spectrum=spectrum,
        continuum_initial=continuum,
        continuum=continuum,
        hbeta=complexes.get("hbeta_oiii"),
        hbeta_initial=complexes.get("hbeta_oiii"),
        mgii=complexes.get("mgii"),
        halpha=complexes.get("halpha_nii_sii"),
        line_complexes=complexes,
        complex_statuses={
            key: str(value)
            for key, value in _from_key_values(
                object_row["complex_statuses"]
            ).items()
        },
        host_decomp_enabled=host_enabled,
        total_spectrum=total_spectrum,
        host_model_on_quasar_grid=(
            np.asarray(row["host_model"], dtype=float)
            if row["host_model"] is not None else None
        ),
        host_fit_mask=host_fit_mask,
        host_emission_mask=host_emission_mask,
        host_warnings=host_warnings,
        warnings=warnings,
        metadata=workflow_metadata,
    )
    return workflow
[docs]
def finalize_run(
    run: Union[str, os.PathLike, RunStore],
    *,
    compact_models: bool = False,
) -> Dict[str, str]:
    """Validate canonical datasets and finalize a resumable run.

    Every table is read once as a validation pass, an empty staging
    directory is removed, and the manifest is marked ``complete``.

    Args:
        run: An opened ``RunStore`` or the path of a run bundle.
        compact_models: Accepted for compatibility; not used by this
            function.

    Returns:
        Mapping of table name to its shard directory, for the tables whose
        directory exists.

    Raises:
        ValueError: If an object key occurs more than once in ``objects``.
    """
    # Accept ``pathlib.Path`` (any ``os.PathLike``) as well as ``str``.
    store = open_run(run) if isinstance(run, (str, os.PathLike)) else run
    object_keys = store.read_table("objects", columns=["object_key"]).column(
        "object_key"
    ).to_pylist()
    # Single pass instead of ``list.count`` per key, which was quadratic.
    seen: set = set()
    repeated: set = set()
    for key in object_keys:
        if key in seen:
            repeated.add(key)
        else:
            seen.add(key)
    duplicates = sorted(repeated)
    if duplicates:
        raise ValueError(f"Duplicate object keys in run: {duplicates[:10]}")
    outputs: Dict[str, str] = {}
    for name in TABLE_NAMES:
        # Reading surfaces unreadable or inconsistent shards before the run
        # is declared complete.
        store.read_table(name)
        table_path = store._table_path(name)
        if table_path.exists():
            outputs[name] = str(table_path)
    staging = store._staging_path()
    if staging.exists() and not any(staging.iterdir()):
        staging.rmdir()
    store.manifest["status"] = "complete"
    store.manifest["finalized_at"] = _now()
    store.manifest["datasets"] = outputs
    store.manifest.pop("compact_outputs", None)
    store._write_manifest()
    return outputs
[docs]
def build_science_catalog(
    run: Union[str, os.PathLike, RunStore],
    specification: Optional[Mapping[str, Mapping[str, Any]]] = None,
    *,
    output_path: Optional[str] = None,
) -> pd.DataFrame:
    """Materialize a provisional wide catalog from long measurements.

    Args:
        run: An opened ``RunStore`` or the path of a run bundle.
        specification: Mapping of output column name to a selector. A
            selector may constrain ``section``, ``recipe_id``,
            ``feature_id``, ``role`` and ``quantity``; unless it sets
            ``include_error`` to false, a ``<name>_err`` column is added.
            Without a specification the ``objects`` table is returned as is.
        output_path: When given, the catalog is also written there as
            Parquet.

    Returns:
        The catalog, one row per entry of the ``objects`` table.
    """
    # Accept ``pathlib.Path`` (any ``os.PathLike``) as well as ``str``.
    store = open_run(run) if isinstance(run, (str, os.PathLike)) else run
    catalog = store.read_table("objects").to_pandas()
    if specification:
        measurements = store.read_table("measurements").to_pandas()
        catalog = catalog.copy()
        for output_name, selector in specification.items():
            selected = measurements
            for key in ("section", "recipe_id", "feature_id", "role", "quantity"):
                if selector.get(key) is not None:
                    selected = selected[selected[key] == selector[key]]
            by_object = selected.set_index("object_key")
            catalog[output_name] = catalog["object_key"].map(by_object["value"])
            if selector.get("include_error", True):
                catalog[f"{output_name}_err"] = catalog["object_key"].map(
                    by_object["error"]
                )
    # Written in both cases: previously ``output_path`` was silently ignored
    # when no specification was given.
    if output_path is not None:
        catalog.to_parquet(output_path, index=False)
    return catalog
[docs]
def compute_derived_quantities(
    run: Union[str, os.PathLike, RunStore],
    calculators: Mapping[str, Callable[[Mapping[str, Any]], Any]],
) -> pd.DataFrame:
    """Run calibration-neutral user calculators and archive long-form results.

    Each calculator is called once per row of the ``objects`` table with a
    context ``{"object": <objects row>, "measurements": [<measurement rows>]}``.
    It may return one entry mapping, a list or tuple of entry mappings, or
    ``None`` when it has nothing to report. An entry must provide
    ``quantity`` and may provide ``value``, ``statistical_error``,
    ``intrinsic_scatter``, ``total_error``, ``unit`` and ``metadata``.

    Args:
        run: An opened ``RunStore`` or the path of a run bundle.
        calculators: Mapping of calibration id to calculator callable.

    Returns:
        All produced rows as a data frame in the ``derived`` schema. The
        rows are also written to the run, one payload per object.
    """
    # Accept ``pathlib.Path`` (any ``os.PathLike``) as well as ``str``.
    store = open_run(run) if isinstance(run, (str, os.PathLike)) else run
    rows = []
    measurements_by_object: Dict[str, list[dict[str, Any]]] = {}
    for measurement in store.read_table("measurements").to_pylist():
        measurements_by_object.setdefault(
            measurement["object_key"], []
        ).append(measurement)
    for object_row in store.read_table("objects").to_pylist():
        key = object_row["object_key"]
        context = {
            "object": object_row,
            "measurements": measurements_by_object.get(key, []),
        }
        for calibration_id, calculator in calculators.items():
            calculated = calculator(context)
            if calculated is None:
                # The calculator has no result for this object.
                continue
            entries = (
                calculated
                if isinstance(calculated, (list, tuple))
                else [calculated]
            )
            for entry in entries:
                rows.append(
                    {
                        "run_id": store.run_id,
                        "object_key": key,
                        "object_id": object_row["object_id"],
                        "quantity": str(entry["quantity"]),
                        "calibration_id": str(calibration_id),
                        "value": _float(entry.get("value")),
                        "statistical_error": _float(
                            entry.get("statistical_error")
                        ),
                        "intrinsic_scatter": _float(
                            entry.get("intrinsic_scatter")
                        ),
                        "total_error": _float(entry.get("total_error")),
                        "unit": entry.get("unit"),
                        "metadata": _key_values(entry.get("metadata", {})),
                    }
                )
    if rows:
        # Written only after every calculator has run. One payload per
        # object keeps each shard tied to a single object key.
        by_object: Dict[str, list[dict[str, Any]]] = {}
        for row in rows:
            by_object.setdefault(row["object_key"], []).append(row)
        for object_rows in by_object.values():
            store.write_payload({"derived": object_rows})
    return pa.Table.from_pylist(rows, schema=SCHEMAS["derived"]).to_pandas()