Source code for qsospec.io.qa

"""Post-fit QA rendering from Parquet model archives."""

from __future__ import annotations

from dataclasses import asdict
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

import numpy as np

from .products import (
    GlobalQAPlotConfig,
    _plot_paths,
    _plot_qa,
    _qa_object_name,
    _save_figure,
    _normalized_file_label,
)
from .run_store import RunStore, load_model, open_run


[docs] def render_qa( run: Union[str, RunStore], *, object_ids: Optional[Sequence[str]] = None, warning_codes: Optional[Sequence[str]] = None, query: Optional[str] = None, sample: Optional[int] = None, random_seed: int = 12345, include_failed: bool = False, plot_config: Optional[GlobalQAPlotConfig] = None, output_dir: Optional[str] = None, ) -> Dict[str, Dict[str, str]]: """Render main QA figures without refitting archived objects.""" store = open_run(run) if isinstance(run, str) else run objects = store.read_table("objects").to_pandas() if query: objects = objects.query(query) if object_ids is not None: requested = set(map(str, object_ids)) objects = objects[ objects["object_id"].astype(str).isin(requested) | objects["object_key"].astype(str).isin(requested) ] if warning_codes: requested_warnings = set(map(str, warning_codes)) objects = objects[ objects["warning_codes"].map( lambda values: bool(requested_warnings.intersection(values or ())) ) ] if sample is not None and len(objects) > int(sample): rng = np.random.default_rng(random_seed) indices = np.sort( rng.choice(len(objects), size=int(sample), replace=False) ) objects = objects.iloc[indices] config = plot_config or GlobalQAPlotConfig() destination = Path(output_dir).expanduser() if output_dir else store.path / "qa" destination.mkdir(parents=True, exist_ok=True) outputs: Dict[str, Dict[str, str]] = {} object_id_counts = objects["object_id"].astype(str).value_counts() inputs = store.read_table("inputs").to_pandas() row_indices = ( inputs.set_index("object_key")["row_index"].to_dict() if not inputs.empty else {} ) for row in objects.sort_values("object_key").to_dict("records"): result = load_model(store, row["object_key"]) object_id = str(row["object_id"]) duplicate_id = int(object_id_counts.get(object_id, 0)) > 1 row_index = row_indices.get(row["object_key"]) file_object_name = ( f"{object_id}_row_{row_index}" if duplicate_id else object_id ) if config.object_name in (None, ""): object_config = GlobalQAPlotConfig( **{ **asdict(config), "object_name": file_object_name, } ) else: object_config = config label = _qa_object_name(result, object_config) saved = _plot_qa( result, _plot_paths(destination, f"main_qa_{label}", object_config), object_config, ) primary = saved.get("png", next(iter(saved.values()))) output_key = str(row["object_key"]) if duplicate_id else object_id outputs[output_key] = { **saved, **{f"main_qa_{key}": value for key, value in saved.items()}, "main_qa": primary, "global_plot": primary, "qa_plot": primary, } if include_failed: import matplotlib.pyplot as plt failures = store.read_table("failures").to_pandas() if object_ids is not None: requested = set(map(str, object_ids)) failures = failures[ failures["object_id"].astype(str).isin(requested) | failures["object_key"].astype(str).isin(requested) ] for row in failures.sort_values("object_key").to_dict("records"): label = _normalized_file_label( row["object_id"] or row["object_key"] ) paths = _plot_paths(destination, f"failed_qa_{label}", config) fig, axis = plt.subplots( figsize=(config.figure_width, 3.2), constrained_layout=True, ) axis.axis("off") axis.text( 0.02, 0.95, f"Fit failed — {row['object_id'] or row['object_key']}", transform=axis.transAxes, ha="left", va="top", fontsize=13, ) axis.text( 0.02, 0.78, f"{row['exception_type']}: {row['message']}", transform=axis.transAxes, ha="left", va="top", fontsize=10, wrap=True, ) outputs[str(row["object_id"] or row["object_key"])] = _save_figure( fig, paths ) plt.close(fig) return outputs