# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Orchestrate multiple benchmark experiments and collect/plot results."""
from __future__ import annotations
import json
import logging
import typing as tp
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from pydantic import Field
from tqdm import tqdm
import neuralset as ns
from .plots.benchmark import plot_all_results
from .plots.tables import print_skip_table
if tp.TYPE_CHECKING:
from .main import Experiment
LOGGER = logging.getLogger(__name__)
def _default_output_dir() -> str:
"""Resolve the default output directory under the user-configured ``SAVE_DIR``.
Imported lazily so that ``BenchmarkAggregator`` can be imported before
``neuralbench.config_manager`` has been initialised (avoids triggering
interactive setup at import time).
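
    For example, with ``SAVE_DIR`` set to ``/checkpoint/me/neuralbench`` (a
    hypothetical value), this returns ``/checkpoint/me/neuralbench/outputs``.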
"""
from neuralbench.config_manager import get_config
return str(Path(get_config()["SAVE_DIR"]) / "outputs")
class BenchmarkAggregator(ns.BaseModel):
"""Orchestrate multiple :class:`Experiment` runs and visualise results.
Experiments are submitted (possibly via Slurm) with :meth:`prepare`,
collected with :meth:`_collect_results`, and plotted/tabled via the
functions in :mod:`neuralbench.plots.benchmark`.
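
    Example (a minimal sketch; constructing the ``Experiment`` objects is
    elided here and depends on :mod:`.main`)::

        agg = BenchmarkAggregator(experiments=experiments)
        agg.prepare()                        # launch anything not yet cached
        results = agg.run(cached_only=True)  # collect, save stats, plot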
"""
experiments: list["Experiment"]
max_workers: int = 256
collect_max_workers: int = 32
debug: bool = False
output_dir: str = Field(default_factory=_default_output_dir)
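    # Metric to report for each loss type, keyed by the loss class name
    # (matched against ``type(experiment.loss).__name__`` at collection time).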
loss_to_metric_mapping: dict[str, str] = {
"CrossEntropyLoss": "test/bal_acc",
"BCEWithLogitsLoss": "test/f1_score_macro",
"MSELoss": "test/pearsonr",
"ClipLoss": "test/full_retrieval/top5_acc_subject-agg",
}
def prepare(self) -> None:
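        """Launch every experiment that is not already cached or running.

        In ``debug`` mode the experiments are run sequentially in-process;
        otherwise they are submitted as a single job array with at most
        ``max_workers`` concurrent tasks.
        """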
n_total = len(self.experiments)
statuses = [exp.infra.status() for exp in self.experiments]
n_completed = statuses.count("completed")
n_running = statuses.count("running")
n_cached = n_completed + n_running
is_force = self.experiments[0].infra.mode == "force"
if n_cached == n_total:
parts = []
if n_completed:
parts.append(f"{n_completed} completed")
if n_running:
parts.append(f"{n_running} still running")
if is_force:
LOGGER.info(
"All %d experiment(s) already cached/running (%s). "
"Re-running with --force.",
n_total,
", ".join(parts),
)
else:
LOGGER.info(
"All %d experiment(s) already cached/running (%s). "
"Nothing to launch. Use --force to re-run.",
n_total,
", ".join(parts),
)
return
if self.debug:
for experiment in self.experiments:
experiment.run()
else:
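            # Clone one experiment's infra only to obtain a job-array context;
            # appending the experiments to ``tasks`` submits them as one array.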
tmp = self.experiments[0].infra.clone_obj()
with tmp.infra.job_array(max_workers=self.max_workers) as tasks:
tasks.extend(self.experiments)
@staticmethod
def _process_one_experiment(
experiment: "Experiment", cached_only: bool
) -> tuple[dict[str, tp.Any] | None, str, str]:
"""Process one experiment, returning ``(result_or_None, task, model)``."""
task = experiment.task_name
model = experiment.brain_model_name
if cached_only and experiment.infra.status() != "completed":
return None, task, model
out = experiment.run()
out["task_name"] = experiment.task_name
study = experiment.data.study
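        # ``ns.Chain`` wraps a sequence of studies; report the first entry's
        # class name as the dataset name.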
if isinstance(study, ns.Chain):
out["dataset_name"] = type(study[0]).__name__
else:
out["dataset_name"] = type(study).__name__
out["brain_model_name"] = experiment.brain_model_name
out["loss"] = {"name": type(experiment.loss).__name__}
out["seed"] = experiment.seed
return out, task, model
def _collect_results(self, cached_only: bool = False) -> list[dict[str, tp.Any]]:
"""Gather experiment results, optionally skipping uncached ones.
When *cached_only* is ``True`` the work is purely I/O-bound (pickle
loads), so experiments are processed in parallel using threads.
"""
if cached_only:
return self._collect_results_parallel()
return self._collect_results_sequential(cached_only=False)
def _collect_results_sequential(
self, cached_only: bool = False
) -> list[dict[str, tp.Any]]:
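        """Run (or load) each experiment in turn, tallying skips per (task, model)."""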
results: list[dict[str, tp.Any]] = []
total: dict[tuple[str, str], int] = {}
skipped: dict[tuple[str, str], int] = {}
for experiment in self.experiments:
out, task, model = self._process_one_experiment(experiment, cached_only)
key = (task, model)
total[key] = total.get(key, 0) + 1
if out is None:
skipped[key] = skipped.get(key, 0) + 1
else:
results.append(out)
if skipped:
print_skip_table(total, skipped)
return results
def _collect_results_parallel(self) -> list[dict[str, tp.Any]]:
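        """Load cached results concurrently, skipping uncached experiments.

        Uses a thread pool because the per-experiment work is I/O-bound
        (loading pickled results from disk).
        """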
results: list[dict[str, tp.Any]] = []
total: dict[tuple[str, str], int] = {}
skipped: dict[tuple[str, str], int] = {}
n_workers = min(self.collect_max_workers, len(self.experiments))
with ThreadPoolExecutor(max_workers=n_workers) as pool:
futures = {
pool.submit(self._process_one_experiment, exp, True): exp
for exp in self.experiments
}
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Collecting cached results",
):
out, task, model = future.result()
key = (task, model)
total[key] = total.get(key, 0) + 1
if out is None:
skipped[key] = skipped.get(key, 0) + 1
else:
results.append(out)
if skipped:
print_skip_table(total, skipped)
return results
def _save_computational_stats(self, results: list[dict[str, tp.Any]]) -> Path:
"""Write per-experiment computational stats to JSON for later analysis."""
_COMP_KEYS = (
"training_time_s",
"peak_gpu_memory_mb",
"peak_cpu_memory_mb",
"n_total_params",
"n_trainable_params",
)
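        # Each JSON entry combines the identifiers below with the stats above,
        # e.g. (illustrative values only):
        # {"task_name": "...", "model": "...", "seed": 0,
        #  "training_time_s": 123.4, "peak_gpu_memory_mb": 2048.0, ...}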
stats = []
for r in results:
entry: dict[str, tp.Any] = {
"task_name": r.get("task_name"),
"dataset_name": r.get("dataset_name"),
"model": r.get("brain_model_name", "unknown"),
"seed": r.get("seed"),
}
entry.update({k: r.get(k) for k in _COMP_KEYS})
stats.append(entry)
out_path = Path(self.output_dir) / "other" / "computational_stats.json"
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, "w") as f:
json.dump(stats, f, indent=2)
LOGGER.info("Saved computational stats to %s", out_path)
return out_path
def run(self, cached_only: bool = False) -> list[dict[str, tp.Any]]:
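        """Collect results, save computational stats, and produce all plots.

        Returns the collected per-experiment result dictionaries, or an empty
        list when nothing could be collected.
        """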
results = self._collect_results(cached_only=cached_only)
if not results:
if cached_only:
# --plot-cached only finds canonical (non-debug) runs: --debug
# uses a reduced config (fewer epochs, smaller batch, subset
# query) and is therefore cached under a different key.
LOGGER.info(
"No cached results found. Nothing to plot.\n"
" --plot-cached looks for completed canonical runs only; "
"--debug runs are cached separately and are not picked up.\n"
" To populate the cache, run the canonical command first "
"(drop --debug), e.g.:\n"
" neuralbench <device> <task> -m <model>\n"
" then re-run with --plot-cached:\n"
" neuralbench <device> <task> -m <model> --plot-cached"
)
else:
LOGGER.info("No results found. Nothing to plot.")
return []
self._save_computational_stats(results)
plot_all_results(results, self.loss_to_metric_mapping, self.output_dir)
return results