# Source code for qsospec.workflows.batch

"""Single-object and resumable process-parallel qsospec execution."""

from __future__ import annotations

from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
from dataclasses import asdict, dataclass
import hashlib
import json
import os
from pathlib import Path
import traceback
from typing import Any, Dict, Iterator, Optional, Sequence, Union

import numpy as np

from .host.io import SpectrumData

from ..complex_recipes import ComplexRecipe
from ..config import (
    GalacticExtinctionConfig,
    GlobalContinuumConfig,
    HalphaComplexConfig,
    HbetaComplexConfig,
    LyaNVComplexConfig,
    MgIIComplexConfig,
    UncertaintyConfig,
)
from ..extinction import (
    correct_spectrum_data,
    prepare_spectrum,
    preflight_galactic_extinction,
)
from ..fitting.global_fit import fit_global_lines
from ..io.products import GlobalQAPlotConfig, write_global_line_products
from ..global_result import WorkflowResult
from .host_workflow import (
    _host_decomp_decision,
    _host_subtracted_spectrum,
    _spectrum_from_spectrum_data,
)
from ..io.readers import (
    SpectrumInput,
    discover_fits_inputs,
    read_input_manifest,
    read_spectrum,
    scan_parquet_spectra,
)
from ..io.run_store import RunStore, finalize_run, workflow_payload
from ..spectrum import Spectrum
from ..warnings import FitWarning


@dataclass
class BatchResult:
    """Summary of one batch invocation.

    Instances are built at the end of ``fit_batch`` from the counters that
    function maintains while it schedules work.
    """

    # String form of the run store path the batch wrote into.
    run_directory: str
    # Identifier of the run as reported by the run store.
    run_id: str
    # Number of per-object tasks handed to the (serial or pooled) runner.
    n_submitted: int
    # Number of tasks whose worker output reported success.
    n_completed: int
    # Number of tasks that reported failure or whose worker future raised.
    n_failed: int
    # Number of selected objects skipped because they were already recorded
    # as completed (or as failed when failures are not retried).
    n_skipped: int
    # Worker count actually used after "auto" resolution and fallback.
    n_workers: int
    # Value returned by ``finalize_run``; empty when finalisation is skipped.
    datasets: Dict[str, str]
@dataclass
class _Task:
    """One unit of work handed to ``_run_task``."""

    # Describes where the spectrum comes from (source, row, reader, ...).
    descriptor: SpectrumInput
    # Pre-loaded spectrum, or None when the worker must read it itself.
    spectrum_data: Optional[SpectrumData]
    # Path of the run store; re-opened inside the worker.
    run_directory: str
    # Keyword arguments forwarded to ``_fit_spectrum_data``.
    fit_options: Dict[str, Any]
    # When True the worker also writes legacy per-object products.
    legacy_output: bool = False


def _auto_workers(number_of_objects: Optional[int]) -> int:
    """Pick a worker count: one less than the CPU count, at least 1, at most 8.

    When ``number_of_objects`` is given, the count is additionally capped at
    that number (but never below 1).
    """
    cpu_count = os.cpu_count() or 1
    available = max(cpu_count - 1, 1)
    if number_of_objects is not None:
        available = min(available, max(number_of_objects, 1))
    return min(available, 8)


def _worker_initializer() -> None:
    """Set the numerical-library thread-count environment variables to ``"1"``.

    NOTE(review): presumably this prevents thread oversubscription when many
    worker processes run side by side -- confirm the intent with the author.
    """
    for variable in (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ):
        os.environ[variable] = "1"


def _process_pool_available() -> bool:
    """Return True when a process pool looks usable on this platform.

    The probe imports ``multiprocessing.synchronize`` and, where the platform
    offers ``os.sysconf``, queries ``SC_SEM_NSEMS_MAX``.  Any of the listed
    errors means "not available".
    """
    try:
        import multiprocessing.synchronize  # noqa: F401

        # ``os.sysconf`` only exists on Unix.  Calling it unconditionally
        # raised an uncaught AttributeError elsewhere (e.g. Windows), so the
        # semaphore query is skipped when the function is missing.
        if hasattr(os, "sysconf"):
            os.sysconf("SC_SEM_NSEMS_MAX")
    except (ImportError, NotImplementedError, OSError, PermissionError, ValueError):
        return False
    return True


def _seed_for(run_id: str, object_key: str, base_seed: Optional[int]) -> int:
    """Derive a deterministic 32-bit seed for one object of one run.

    The seed is the first four bytes (little endian) of the SHA-256 digest of
    ``"{run_id}|{object_key}|{base_seed}"``.
    """
    payload = f"{run_id}|{object_key}|{base_seed}".encode("utf-8")
    return int.from_bytes(hashlib.sha256(payload).digest()[:4], "little")


def _fit_spectrum_data(
    spectrum_data: SpectrumData,
    *,
    descriptor: SpectrumInput,
    run_host_decomp: bool,
    template_root: str,
    template_file: str,
    host_fit_range,
    host_config,
    galactic_extinction_config,
    global_config,
    hbeta_config,
    mgii_config,
    halpha_config,
    lya_nv_config,
    uncertainty_config,
    complexes,
):
    """Fit one spectrum and annotate the result with provenance metadata.

    Steps, in order: apply ``correct_spectrum_data``; decide whether host
    decomposition runs via ``_host_decomp_decision``; build the spectrum to
    fit (host-subtracted or not); run ``fit_global_lines``; attach host
    products and metadata to the result.

    Returns:
        ``(result, object_id)`` where ``object_id`` is a string chosen from
        the descriptor, the spectrum data, or the source file stem.
    """
    spectrum_data = correct_spectrum_data(
        spectrum_data, galactic_extinction_config
    )
    source = (
        f"{descriptor.source}:row_index={descriptor.row_index}"
        if descriptor.row_index is not None
        else descriptor.source
    )
    host_decomp_enabled, host_skip_reason = _host_decomp_decision(
        run_host_decomp, spectrum_data.redshift
    )
    if host_decomp_enabled:
        (
            total_spectrum,
            fit_spectrum,
            host_fit,
            host_sed,
            host_on_grid,
            _,
            host_warnings,
        ) = _host_subtracted_spectrum(
            spectrum_data,
            redshift=float(spectrum_data.redshift),
            template_root=template_root,
            template_file=template_file,
            fit_range=host_fit_range,
            host_config=host_config,
            source=source,
        )
    else:
        # Without host decomposition the total spectrum is fitted directly.
        total_spectrum = _spectrum_from_spectrum_data(
            spectrum_data, source=source
        )
        fit_spectrum = total_spectrum
        host_fit = None
        host_sed = None
        host_on_grid = None
        host_warnings = []

    result = fit_global_lines(
        fit_spectrum,
        global_config,
        hbeta_config,
        mgii_config,
        halpha_config,
        uncertainty_config,
        lya_nv_config=lya_nv_config,
        host_model_on_grid=host_on_grid,
        complexes=complexes,
    )
    result.host_decomp_enabled = host_decomp_enabled
    result.total_spectrum = total_spectrum
    result.host_fit = host_fit
    result.host_sed = host_sed
    result.host_model_on_quasar_grid = host_on_grid
    # Masks are copied so the result does not alias the host-fit arrays.
    result.host_fit_mask = (
        np.asarray(host_fit.preprocessed.fit_mask, dtype=bool).copy()
        if host_fit is not None
        else None
    )
    result.host_emission_mask = (
        np.asarray(host_fit.preprocessed.emission_mask, dtype=bool).copy()
        if host_fit is not None
        else None
    )
    result.host_warnings = [str(item) for item in host_warnings]

    # First truthy identifier wins; the file stem is the last resort.
    object_id = (
        descriptor.object_id
        or spectrum_data.object_id
        or spectrum_data.targetid
        or Path(descriptor.source).stem
    )
    result.metadata.update(
        {
            "input_path": descriptor.source,
            "row_index": descriptor.row_index,
            "object_id": str(object_id),
            "targetid": spectrum_data.targetid,
            "ra": spectrum_data.ra,
            "dec": spectrum_data.dec,
            "redshift": fit_spectrum.z,
            "fit_kind": "global",
            "host_decomp_requested": bool(run_host_decomp),
            "host_decomp_enabled": host_decomp_enabled,
            "host_decomp_skip_reason": host_skip_reason,
            "host_model_source": (
                "template_weighted_sed_on_quasar_grid"
                if host_decomp_enabled
                else None
            ),
            "host_fit_range": list(host_fit_range),
            "host_mask_provenance": (
                "exact" if host_decomp_enabled else "unavailable"
            ),
            "host_ppxf_status": (
                host_fit.status if host_fit is not None else None
            ),
            "host_ppxf_reduced_chi2": (
                float(host_fit.reduced_chi2) if host_fit is not None else None
            ),
            "host_fit_reliable": (
                bool(host_fit.host_fit_reliable)
                if host_fit is not None
                else None
            ),
            "host_fit_reliability_reasons": (
                list(host_fit.host_fit_reliability_reasons)
                if host_fit is not None
                else []
            ),
            "host_fit_quality": (
                dict(host_fit.quality_metrics) if host_fit is not None else {}
            ),
            "host_noise_rescale_factors": (
                dict(host_fit.noise_rescale_factors)
                if host_fit is not None
                else {}
            ),
            # Only the "*_log" entries and "log_grid_valid" are stored in full.
            "host_mask_components_log": (
                {
                    key: np.asarray(value, dtype=bool).tolist()
                    for key, value in host_fit.preprocessed.mask_provenance.items()
                    if str(key).endswith("_log") or str(key) == "log_grid_valid"
                }
                if host_fit is not None
                else {}
            ),
            "host_mask_component_counts": (
                {
                    key: int(np.count_nonzero(value))
                    for key, value in host_fit.preprocessed.mask_provenance.items()
                }
                if host_fit is not None
                else {}
            ),
            "host_template_file": (
                host_fit.templates.source_path if host_fit is not None else None
            ),
            "host_template_wavelength_coverage": (
                list(host_fit.templates.wavelength_coverage)
                if host_fit is not None
                else None
            ),
            "galactic_extinction": dict(
                spectrum_data.metadata.get("galactic_extinction", {})
            ),
        }
    )
    if run_host_decomp and not host_decomp_enabled:
        # NOTE(review): the threshold is hard-coded here; presumably it
        # mirrors the gate inside _host_decomp_decision -- confirm.
        result.warnings.append(
            FitWarning(
                code="host_decomp_skipped_redshift",
                message="Host decomposition was requested but skipped by the redshift gate.",
                severity="info",
                context={
                    "redshift": spectrum_data.redshift,
                    "threshold": 1.2,
                    "reason": host_skip_reason,
                },
            )
        )
    return result, str(object_id)


def _failure_payload(
    store: RunStore,
    descriptor: SpectrumInput,
    exception: BaseException,
) -> Dict[str, list[dict[str, Any]]]:
    """Build the ``inputs``/``failures`` records describing a failed object.

    The traceback is rendered from ``exception`` itself rather than from the
    interpreter's "current" exception, so the payload is correct even when
    this function is called outside the ``except`` block that caught it.
    """
    from datetime import datetime, timezone

    formatted_traceback = "".join(
        traceback.format_exception(
            type(exception), exception, exception.__traceback__
        )
    )
    return {
        "inputs": [
            {
                "run_id": store.run_id,
                "object_key": descriptor.object_key,
                "object_id": descriptor.object_id,
                "source": descriptor.source,
                "row_index": descriptor.row_index,
                "reader": descriptor.reader,
                "redshift": descriptor.redshift,
                # Metadata values are JSON-encoded; non-serialisable values
                # fall back to their repr.
                "metadata": [
                    {
                        "key": str(key),
                        "value": json.dumps(value, sort_keys=True, default=repr),
                    }
                    for key, value in sorted(descriptor.metadata.items())
                ],
            }
        ],
        "failures": [
            {
                "run_id": store.run_id,
                "object_key": descriptor.object_key,
                "object_id": descriptor.object_id,
                "source": descriptor.source,
                "row_index": descriptor.row_index,
                "exception_type": type(exception).__name__,
                "message": str(exception),
                "traceback": formatted_traceback,
                "failed_at": datetime.now(timezone.utc).isoformat(),
                "metadata": [],
            }
        ],
    }


def _run_task(task: _Task) -> Dict[str, Any]:
    """Fit one task and stage its payload; never raises for fit errors.

    Returns a dict with ``success``, ``object_key``, ``object_id`` and
    ``staging`` plus either ``legacy_files`` (success) or ``error`` (failure).
    """
    _worker_initializer()
    store = RunStore.open(task.run_directory)
    descriptor = task.descriptor
    try:
        # Explicit None check: the truthiness of a data object must not
        # decide whether the spectrum is re-read from disk.
        spectrum_data = task.spectrum_data
        if spectrum_data is None:
            spectrum_data = read_spectrum(
                descriptor.source,
                row_index=descriptor.row_index,
                redshift=descriptor.redshift,
                object_id=descriptor.object_id,
                reader=descriptor.reader,
            )
        options = dict(task.fit_options)
        uncertainty = options["uncertainty_config"]
        # Replace the seed with a per-object one derived from the run id.
        # NOTE(review): only these four fields are forwarded; confirm that
        # UncertaintyConfig has no other fields that should be preserved.
        options["uncertainty_config"] = UncertaintyConfig(
            covariance=uncertainty.covariance,
            monte_carlo_trials=uncertainty.monte_carlo_trials,
            random_seed=_seed_for(
                store.run_id,
                descriptor.object_key,
                uncertainty.random_seed,
            ),
            refit_host_in_mc=uncertainty.refit_host_in_mc,
        )
        result, object_id = _fit_spectrum_data(
            spectrum_data,
            descriptor=descriptor,
            **options,
        )
        payload = workflow_payload(
            result,
            run_id=store.run_id,
            object_key=descriptor.object_key,
            object_id=object_id,
            input_record={
                "source": descriptor.source,
                "row_index": descriptor.row_index,
                "reader": descriptor.reader,
                "metadata": dict(descriptor.metadata),
            },
        )
        staging = store.stage_payload(payload)
        legacy_files = {}
        if task.legacy_output:
            legacy_files = write_global_line_products(
                result,
                str(store.path / "legacy" / object_id),
            )
        return {
            "success": True,
            "object_key": descriptor.object_key,
            "object_id": object_id,
            "staging": str(staging),
            "legacy_files": legacy_files,
        }
    except Exception as exception:
        # Task boundary: the failure is recorded in the store instead of
        # propagating, so one bad object does not abort the batch.
        staging = store.stage_payload(
            _failure_payload(store, descriptor, exception)
        )
        return {
            "success": False,
            "object_key": descriptor.object_key,
            "object_id": descriptor.object_id,
            "staging": str(staging),
            "error": str(exception),
        }


def _run_task_group(tasks: Sequence[_Task]) -> list[Dict[str, Any]]:
    """Run one Parquet microbatch or one FITS task inside a worker."""
    return [_run_task(task) for task in tasks]


def _configuration(
    *,
    run_host_decomp,
    template_root,
    template_file,
    host_fit_range,
    host_config,
    galactic_extinction_config,
    global_config,
    hbeta_config,
    mgii_config,
    halpha_config,
    lya_nv_config,
    uncertainty_config,
    complexes,
) -> Dict[str, Any]:
    """Flatten the fit settings into a plain dict for the run store.

    Dataclass configs are expanded with ``asdict``; a missing
    ``global_config`` is recorded as ``{"preset": "automatic_lya_safe"}``.
    """
    return {
        "run_host_decomp": bool(run_host_decomp),
        "template_root": str(template_root),
        "template_file": str(template_file),
        "host_fit_range": tuple(host_fit_range),
        "host_config": host_config,
        "galactic_extinction_config": asdict(galactic_extinction_config),
        "global_config": (
            asdict(global_config)
            if global_config is not None
            else {"preset": "automatic_lya_safe"}
        ),
        "hbeta_config": asdict(hbeta_config),
        "mgii_config": asdict(mgii_config),
        "halpha_config": asdict(halpha_config),
        "lya_nv_config": asdict(lya_nv_config),
        "uncertainty_config": asdict(uncertainty_config),
        "complexes": [
            asdict(item) if isinstance(item, ComplexRecipe) else str(item)
            for item in complexes
        ]
        if complexes is not None
        else None,
    }


def _configuration_overrides(configuration: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the entries of ``configuration`` that differ from defaults.

    Nested dicts are compared key by key, and sub-dicts that end up empty are
    dropped from the summary.
    """
    defaults = _configuration(
        run_host_decomp=False,
        template_root="~/tools/ppxf_data",
        template_file="spectra_emiles_9.0.npz",
        host_fit_range=(3600.0, 7000.0),
        host_config=None,
        galactic_extinction_config=GalacticExtinctionConfig(),
        global_config=None,
        hbeta_config=HbetaComplexConfig(),
        mgii_config=MgIIComplexConfig(),
        halpha_config=HalphaComplexConfig(),
        lya_nv_config=LyaNVComplexConfig(),
        uncertainty_config=UncertaintyConfig(),
        complexes=None,
    )

    def diff(value, default):
        # Recurse only when both sides are dicts; otherwise the value itself
        # is the override.
        if isinstance(value, dict) and isinstance(default, dict):
            changed = {
                key: diff(item, default.get(key))
                for key, item in value.items()
                if key not in default or item != default.get(key)
            }
            return {key: item for key, item in changed.items() if item != {}}
        return value

    overrides = {}
    for key, value in configuration.items():
        default_value = defaults.get(key)
        if (
            key == "global_config"
            and isinstance(value, dict)
            and "preset" not in value
        ):
            # An explicit global config is compared against the dataclass
            # defaults rather than against the preset placeholder.
            default_value = asdict(GlobalContinuumConfig())
        if key not in defaults or value != default_value:
            changed = diff(value, default_value)
            if changed != {}:
                overrides[key] = changed
    return overrides


def _fit_options(
    *,
    run_host_decomp,
    template_root,
    template_file,
    host_fit_range,
    host_config,
    galactic_extinction_config,
    global_config,
    hbeta_config,
    mgii_config,
    halpha_config,
    lya_nv_config,
    uncertainty_config,
    complexes,
) -> Dict[str, Any]:
    """Bundle the keyword arguments expected by ``_fit_spectrum_data``.

    Unlike ``_configuration`` the config objects are passed through as-is;
    only ``run_host_decomp`` and ``host_fit_range`` are normalised.
    """
    return {
        "run_host_decomp": bool(run_host_decomp),
        "template_root": template_root,
        "template_file": template_file,
        "host_fit_range": tuple(host_fit_range),
        "host_config": host_config,
        "galactic_extinction_config": galactic_extinction_config,
        "global_config": global_config,
        "hbeta_config": hbeta_config,
        "mgii_config": mgii_config,
        "halpha_config": halpha_config,
        "lya_nv_config": lya_nv_config,
        "uncertainty_config": uncertainty_config,
        "complexes": complexes,
    }
def fit_object_to_store(
    input_data: Union[str, SpectrumInput, SpectrumData, Spectrum],
    run_directory: str,
    *,
    row_index: Optional[int] = None,
    redshift: Optional[float] = None,
    object_id: Optional[str] = None,
    reader: str = "auto",
    flux_unit: Optional[str] = None,
    flux_scale: Optional[float] = None,
    run_host_decomp: bool = False,
    template_root: str = "~/tools/ppxf_data",
    template_file: str = "spectra_emiles_9.0.npz",
    host_fit_range=(3600.0, 7000.0),
    host_config=None,
    galactic_extinction_config: Optional[
        GalacticExtinctionConfig
    ] = None,
    global_config: Optional[GlobalContinuumConfig] = None,
    hbeta_config: Optional[HbetaComplexConfig] = None,
    mgii_config: Optional[MgIIComplexConfig] = None,
    halpha_config: Optional[HalphaComplexConfig] = None,
    lya_nv_config: Optional[LyaNVComplexConfig] = None,
    uncertainty_config: Optional[UncertaintyConfig] = None,
    complexes: Optional[Sequence[Union[str, ComplexRecipe]]] = None,
    run_id: Optional[str] = None,
    resume: bool = True,
    write_qa: bool = True,
    qa_plot_config: Optional[GlobalQAPlotConfig] = None,
    write_legacy_products: bool = False,
) -> WorkflowResult:
    """Fit one object into the same run bundle used for batch fitting.

    ``input_data`` may be a ``SpectrumInput`` descriptor, an in-memory
    ``SpectrumData`` or ``Spectrum``, or anything accepted by ``Path`` (a
    file path).  The fit is written to the run store, the run is finalised,
    and ``result.output_files`` is filled with the run directory, the
    manifest path and any legacy/QA products that were requested.

    Returns:
        The ``WorkflowResult`` produced by the fit.
    """
    hbeta_config = hbeta_config or HbetaComplexConfig()
    mgii_config = mgii_config or MgIIComplexConfig()
    halpha_config = halpha_config or HalphaComplexConfig()
    lya_nv_config = lya_nv_config or LyaNVComplexConfig()
    uncertainty_config = uncertainty_config or UncertaintyConfig()
    galactic_extinction_config = (
        galactic_extinction_config or GalacticExtinctionConfig()
    )
    configuration = _configuration(
        run_host_decomp=run_host_decomp,
        template_root=template_root,
        template_file=template_file,
        host_fit_range=host_fit_range,
        host_config=host_config,
        galactic_extinction_config=galactic_extinction_config,
        global_config=global_config,
        hbeta_config=hbeta_config,
        mgii_config=mgii_config,
        halpha_config=halpha_config,
        lya_nv_config=lya_nv_config,
        uncertainty_config=uncertainty_config,
        complexes=complexes,
    )
    store = RunStore.create(
        run_directory,
        configuration=configuration,
        configuration_summary=_configuration_overrides(configuration),
        run_id=run_id,
        resume=resume,
    )

    if isinstance(input_data, SpectrumInput):
        # Descriptor: everything except the flux options comes from it.
        descriptor = input_data
        spectrum_data = read_spectrum(
            descriptor.source,
            row_index=descriptor.row_index,
            redshift=descriptor.redshift,
            object_id=descriptor.object_id,
            reader=descriptor.reader,
            flux_unit=flux_unit,
            flux_scale=flux_scale,
        )
    elif isinstance(input_data, SpectrumData):
        descriptor = SpectrumInput(
            source=str(input_data.metadata.get("input_file", "in_memory")),
            object_id=object_id or input_data.object_id or input_data.targetid,
            # ``is not None`` rather than ``or``: an explicit redshift of 0.0
            # is a valid override and must not fall back to the stored value.
            redshift=(
                redshift if redshift is not None else input_data.redshift
            ),
            reader="memory",
        )
        spectrum_data = input_data
    elif isinstance(input_data, Spectrum):
        prepared_spectrum = prepare_spectrum(
            input_data,
            galactic_extinction_config=galactic_extinction_config,
        )
        descriptor = SpectrumInput(
            source=prepared_spectrum.metadata.source or "in_memory",
            object_id=object_id,
            redshift=prepared_spectrum.z,
            reader="memory",
        )
        spectrum_data = SpectrumData(
            wave_obs=prepared_spectrum.wave_obs,
            flux=prepared_spectrum.flux,
            error=prepared_spectrum.err,
            mask=prepared_spectrum.mask,
            redshift=prepared_spectrum.z,
            object_id=object_id,
            ra=prepared_spectrum.metadata.ra,
            dec=prepared_spectrum.metadata.dec,
            metadata={
                "input_file": descriptor.source,
                "flux_unit": prepared_spectrum.flux_unit,
                "flux_scale": prepared_spectrum.flux_scale,
                "flux_frame": prepared_spectrum.flux_frame,
                "rest_frame_conversion": dict(
                    prepared_spectrum.metadata.rest_frame_conversion
                ),
                "spectrum_metadata": (
                    prepared_spectrum.metadata.to_dict()
                ),
                "galactic_extinction": dict(
                    prepared_spectrum.metadata.galactic_extinction
                ),
            },
        )
    else:
        # Anything else is treated as a path to a spectrum file.
        descriptor = SpectrumInput(
            source=str(Path(input_data).expanduser()),
            row_index=row_index,
            object_id=object_id,
            redshift=redshift,
            reader=reader,
        )
        spectrum_data = read_spectrum(
            descriptor.source,
            row_index=row_index,
            redshift=redshift,
            object_id=object_id,
            reader=reader,
            flux_unit=flux_unit,
            flux_scale=flux_scale,
        )

    result, actual_object_id = _fit_spectrum_data(
        spectrum_data,
        descriptor=descriptor,
        **_fit_options(
            run_host_decomp=run_host_decomp,
            template_root=template_root,
            template_file=template_file,
            host_fit_range=host_fit_range,
            host_config=host_config,
            galactic_extinction_config=galactic_extinction_config,
            global_config=global_config,
            hbeta_config=hbeta_config,
            mgii_config=mgii_config,
            halpha_config=halpha_config,
            lya_nv_config=lya_nv_config,
            uncertainty_config=uncertainty_config,
            complexes=complexes,
        ),
    )
    store.write_payload(
        workflow_payload(
            result,
            run_id=store.run_id,
            object_key=descriptor.object_key,
            object_id=actual_object_id,
            input_record={
                "source": descriptor.source,
                "row_index": descriptor.row_index,
                "reader": descriptor.reader,
                "metadata": dict(descriptor.metadata),
            },
        )
    )
    finalize_run(store)

    result.output_files = {
        "run_directory": str(store.path),
    }
    result.output_files["manifest"] = str(store.path / "manifest.json")
    if write_legacy_products:
        config = qa_plot_config or GlobalQAPlotConfig()
        files = write_global_line_products(
            result,
            str(store.path / "legacy" / actual_object_id),
            config,
        )
        result.output_files.update(files)
    elif write_qa:
        # Imported lazily so the QA renderer is only loaded when needed.
        from ..io.qa import render_qa

        rendered = render_qa(
            store,
            object_ids=[actual_object_id],
            plot_config=qa_plot_config or GlobalQAPlotConfig(),
        )
        result.output_files.update(rendered.get(actual_object_id, {}))
        # Keep only the run-level entries and the main QA file.
        # NOTE(review): the original indentation was lost; this filter is
        # assumed to belong to the QA branch (applying it to the legacy
        # branch would discard the legacy products) -- confirm.
        result.output_files = {
            key: value
            for key, value in result.output_files.items()
            if key in ("run_directory", "manifest", "main_qa")
        }
    return result
def _iter_inputs(
    inputs,
    *,
    row_indices,
    filter_expression,
    parquet_batch_size,
) -> Iterator[tuple[SpectrumInput, Optional[SpectrumData]]]:
    """Yield ``(descriptor, spectrum_data)`` pairs for every requested input.

    ``inputs`` may be a single path, a single ``SpectrumInput`` or an
    iterable mixing both.  Items are classified first, then yielded in this
    order:

    1. rows of Parquet files that contain spectra (via
       ``scan_parquet_spectra``),
    2. descriptors given directly or read from CSV/Parquet manifests,
       paired with ``None``,
    3. descriptors discovered from the remaining paths by
       ``discover_fits_inputs``, paired with ``None``.

    Items that are neither descriptors nor ``str``/``Path`` are ignored.
    """
    # A lone descriptor is wrapped like a lone path so callers do not have
    # to build a one-element list.
    if isinstance(inputs, (str, Path, SpectrumInput)):
        items = [inputs]
    else:
        items = list(inputs)

    # Column names that mark a Parquet file as holding spectra rather than
    # being a manifest of inputs.
    wave_aliases = ("wavelength", "wave", "lambda", "lam", "obs_wave")
    flux_aliases = ("flux", "flam", "flux_lambda")

    parquet_spectra = []
    manifest_descriptors = []
    remaining = []
    for item in items:
        if isinstance(item, SpectrumInput):
            manifest_descriptors.append(item)
            continue
        if not isinstance(item, (str, Path)):
            remaining.append(item)
            continue
        path = Path(item).expanduser()
        suffix = path.suffix.lower()
        if suffix == ".csv":
            manifest_descriptors.extend(read_input_manifest(str(path)))
        elif suffix == ".parquet":
            # Imported lazily: pyarrow is only needed for Parquet inputs.
            import pyarrow.parquet as pq

            columns = {name.lower() for name in pq.read_schema(path).names}
            has_wave = any(alias in columns for alias in wave_aliases)
            has_flux = any(alias in columns for alias in flux_aliases)
            if has_wave and has_flux:
                parquet_spectra.append(str(path))
            else:
                manifest_descriptors.extend(read_input_manifest(str(path)))
        else:
            remaining.append(item)

    if parquet_spectra:
        yield from scan_parquet_spectra(
            parquet_spectra,
            row_indices=row_indices,
            filter_expression=filter_expression,
            batch_size=parquet_batch_size,
        )
    for descriptor in manifest_descriptors:
        yield descriptor, None
    discoverable = [
        str(item) for item in remaining if isinstance(item, (str, Path))
    ]
    if discoverable:
        for descriptor in discover_fits_inputs(discoverable):
            yield descriptor, None
def fit_batch(
    inputs,
    run_directory: str,
    *,
    row_indices=None,
    filter_expression=None,
    parquet_batch_size: int = 128,
    task_size: int = 8,
    n_workers: Union[int, str] = "auto",
    num_shards: int = 1,
    shard_index: int = 0,
    run_host_decomp: bool = False,
    template_root: str = "~/tools/ppxf_data",
    template_file: str = "spectra_emiles_9.0.npz",
    host_fit_range=(3600.0, 7000.0),
    host_config=None,
    galactic_extinction_config: Optional[
        GalacticExtinctionConfig
    ] = None,
    global_config: Optional[GlobalContinuumConfig] = None,
    hbeta_config: Optional[HbetaComplexConfig] = None,
    mgii_config: Optional[MgIIComplexConfig] = None,
    halpha_config: Optional[HalphaComplexConfig] = None,
    lya_nv_config: Optional[LyaNVComplexConfig] = None,
    uncertainty_config: Optional[UncertaintyConfig] = None,
    complexes: Optional[Sequence[Union[str, ComplexRecipe]]] = None,
    run_id: Optional[str] = None,
    resume: bool = True,
    retry_failures: bool = True,
    finalize: bool = True,
    compact_models: bool = False,
    write_legacy_products: bool = False,
) -> BatchResult:
    """Fit a Parquet or FITS sample with resumable process parallelism.

    Inputs are streamed from ``_iter_inputs``, restricted to this shard,
    grouped into tasks and run either in-process (one worker) or on a
    spawn-based ``ProcessPoolExecutor``.  Each finished task's staged payload
    is promoted into the run store by the parent process.

    Args:
        task_size: Maximum number of pre-loaded (Parquet) spectra per task
            group; values below 1 are treated as 1.  Inputs without
            pre-loaded data always form a group of their own.
        n_workers: ``"auto"`` or an integer; values below 1 are treated as 1.
            Falls back to 1 when no process pool is available.
        num_shards, shard_index: Only objects whose SHA-256 key digest modulo
            ``num_shards`` equals ``shard_index`` are processed.
        resume: Skip objects the store already lists as completed.
        retry_failures: When False (and resuming), also skip objects the
            store lists as failed.
        finalize: Run ``finalize_run`` at the end; only done when
            ``num_shards == 1``.
        compact_models: Accepted but not used by this function.

    Raises:
        ValueError: If the shard arguments are inconsistent.
    """
    if num_shards < 1 or not 0 <= shard_index < num_shards:
        raise ValueError("Require num_shards >= 1 and 0 <= shard_index < num_shards.")
    # Normalised once up front instead of on every appended task.
    group_size = max(int(task_size), 1)

    hbeta_config = hbeta_config or HbetaComplexConfig()
    mgii_config = mgii_config or MgIIComplexConfig()
    halpha_config = halpha_config or HalphaComplexConfig()
    lya_nv_config = lya_nv_config or LyaNVComplexConfig()
    uncertainty_config = uncertainty_config or UncertaintyConfig()
    galactic_extinction_config = (
        galactic_extinction_config or GalacticExtinctionConfig()
    )
    preflight_galactic_extinction(galactic_extinction_config)
    configuration = _configuration(
        run_host_decomp=run_host_decomp,
        template_root=template_root,
        template_file=template_file,
        host_fit_range=host_fit_range,
        host_config=host_config,
        galactic_extinction_config=galactic_extinction_config,
        global_config=global_config,
        hbeta_config=hbeta_config,
        mgii_config=mgii_config,
        halpha_config=halpha_config,
        lya_nv_config=lya_nv_config,
        uncertainty_config=uncertainty_config,
        complexes=complexes,
    )
    configuration["num_shards"] = int(num_shards)
    store = RunStore.create(
        run_directory,
        configuration=configuration,
        configuration_summary=_configuration_overrides(configuration),
        run_id=run_id,
        resume=resume,
    )
    completed = store.completed_keys() if resume else set()
    failed = store.failed_keys() if resume and not retry_failures else set()
    options = _fit_options(
        run_host_decomp=run_host_decomp,
        template_root=template_root,
        template_file=template_file,
        host_fit_range=host_fit_range,
        host_config=host_config,
        galactic_extinction_config=galactic_extinction_config,
        global_config=global_config,
        hbeta_config=hbeta_config,
        mgii_config=mgii_config,
        halpha_config=halpha_config,
        lya_nv_config=lya_nv_config,
        uncertainty_config=uncertainty_config,
        complexes=complexes,
    )
    iterator = _iter_inputs(
        inputs,
        row_indices=row_indices,
        filter_expression=filter_expression,
        parquet_batch_size=parquet_batch_size,
    )

    def selected():
        """Yield a ``_Task`` per object of this shard, or None when skipped."""
        for descriptor, spectrum_data in iterator:
            # With a single shard every digest maps to shard 0, so the hash
            # is only computed when the run is actually sharded.
            if num_shards > 1:
                digest = int(
                    hashlib.sha256(
                        descriptor.object_key.encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                if digest % num_shards != shard_index:
                    continue
            if descriptor.object_key in completed or descriptor.object_key in failed:
                yield None
                continue
            yield _Task(
                descriptor=descriptor,
                spectrum_data=spectrum_data,
                run_directory=str(store.path),
                fit_options=options,
                legacy_output=write_legacy_products,
            )

    def grouped_tasks():
        """Batch pre-loaded tasks; pass skips (None) and lone tasks through."""
        parquet_group = []
        for task in selected():
            if task is None:
                yield None
                continue
            if task.spectrum_data is None:
                # Flush the open group first so input order is preserved.
                if parquet_group:
                    yield parquet_group
                    parquet_group = []
                yield [task]
                continue
            parquet_group.append(task)
            if len(parquet_group) >= group_size:
                yield parquet_group
                parquet_group = []
        if parquet_group:
            yield parquet_group

    task_iterator = iter(grouped_tasks())
    worker_count = _auto_workers(None) if n_workers == "auto" else int(n_workers)
    worker_count = max(worker_count, 1)
    if worker_count > 1 and not _process_pool_available():
        worker_count = 1

    submitted = completed_count = failed_count = skipped_count = 0

    def handle(output):
        """Promote one staged payload and update the counters."""
        nonlocal completed_count, failed_count
        store.promote(output["staging"])
        if output["success"]:
            store.clear_failure(output["object_key"])
            completed_count += 1
        else:
            failed_count += 1

    if worker_count == 1:
        for task_group in task_iterator:
            if task_group is None:
                skipped_count += 1
                continue
            submitted += len(task_group)
            for output in _run_task_group(task_group):
                handle(output)
    else:
        # Applied in the parent as well, before the pool is created.
        _worker_initializer()
        import multiprocessing as mp

        with ProcessPoolExecutor(
            max_workers=worker_count,
            mp_context=mp.get_context("spawn"),
            initializer=_worker_initializer,
        ) as executor:
            pending = {}
            exhausted = False
            while pending or not exhausted:
                # Keep at most two task groups in flight per worker so the
                # input stream is consumed lazily.
                while len(pending) < worker_count * 2:
                    try:
                        task_group = next(task_iterator)
                    except StopIteration:
                        exhausted = True
                        break
                    if task_group is None:
                        skipped_count += 1
                        continue
                    submitted += len(task_group)
                    future = executor.submit(_run_task_group, task_group)
                    pending[future] = task_group
                if pending:
                    done, _ = wait(
                        set(pending), return_when=FIRST_COMPLETED
                    )
                    for future in done:
                        task_group = pending.pop(future)
                        try:
                            outputs = future.result()
                        except Exception as exception:
                            # The whole group failed at the worker level:
                            # record a failure for each of its tasks.
                            outputs = []
                            for task in task_group:
                                store.write_payload(
                                    _failure_payload(
                                        store, task.descriptor, exception
                                    )
                                )
                                failed_count += 1
                        for output in outputs:
                            handle(output)

    datasets = (
        finalize_run(store)
        if finalize and num_shards == 1
        else {}
    )
    return BatchResult(
        run_directory=str(store.path),
        run_id=store.run_id,
        n_submitted=submitted,
        n_completed=completed_count,
        n_failed=failed_count,
        n_skipped=skipped_count,
        n_workers=worker_count,
        datasets=datasets,
    )