# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""CLI entry point and programmatic API for NeuralBench.
Parse command-line arguments (``run_benchmark_cli``) or call the Python
API directly (``run_benchmark``). Task/model discovery, validation, and
config assembly are delegated to :mod:`neuralbench.registry` and
:mod:`neuralbench.experiment_config`.
"""
import argparse
import logging
import sys
import traceback
import typing as tp
from exca import ConfDict
from neuralbench.experiment_config import (
_warn_slurm_partition,
prepare_task_configs,
)
from neuralbench.registry import (
ALL_DEVICES,
ALL_DOWNSTREAM_WRAPPERS,
ALL_MODELS,
ALL_TASKS,
ALL_UNVALIDATED_TASKS,
DEFAULTS_DIR,
_expand_models,
_format_datasets_epilog,
_resolve_datasets,
_resolve_tasks,
_validate_inputs,
load_yaml_config,
)
logger = logging.getLogger(__name__)
def run_benchmark(
device: str,
task: str | list[str],
*,
model: str | list[str] | None = None,
dataset: str | list[str] | None = None,
checkpoint: str | None = None,
downstream_wrapper: str | list[str] | None = None,
grid: bool = False,
debug: bool = False,
force: bool = False,
retry: bool = False,
prepare: bool = False,
download: bool = False,
plot_cached: bool = False,
) -> list[dict[str, tp.Any]]:
"""Run one or more NeuralBench experiments from Python.
This is the programmatic equivalent of the ``neuralbench`` CLI.
It assembles experiment configs from the same YAML files and returns
test-metric dictionaries when running in debug mode.
Parameters
----------
device : str
Brain recording device (``"eeg"``, ``"meg"``, ``"fmri"``, ...).
task : str or list of str
Task name(s), ``"all"``, or ``"all_multi_dataset"``.
model : str or list of str or None
Predefined model name(s), ``"all"``, ``"all_classic"``, ``"all_fm"``,
``"all_baseline"`` (chance / dummy / classical sklearn pipelines),
or ``None`` (uses default model from ``config.yaml``).
dataset : str or list of str or None
Dataset variant(s) or ``"all"``. ``None`` uses the base config.
checkpoint : str or None
Path to a model checkpoint to reload.
downstream_wrapper : str or list of str or None
Downstream wrapper name(s) or ``"all"``.
grid : bool
Expand the task-specific hyperparameter grid.
debug : bool
Run locally with a reduced config (2 epochs, 5 batches).
force : bool
Force re-running experiments.
retry : bool
Retry failed experiments while keeping completed results.
prepare : bool
Run a single experiment to warm the preprocessing cache.
download : bool
Only download the dataset; do not run experiments.
plot_cached : bool
Generate plots and tables from cached results only, without
        running any new experiments.

    Returns
-------
list of dict
One result dict per experiment (empty when experiments are
submitted asynchronously via Slurm).
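
    Examples
    --------
    A minimal sketch with illustrative argument values (any device, task,
    and model names accepted by the CLI work here)::

        results = run_benchmark("eeg", "all", model="all_baseline", debug=True)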
"""
logging.basicConfig(level=logging.INFO)
logging.getLogger("numexpr").setLevel(logging.WARNING)
# fontTools.subset emits very chatty INFO logs while matplotlib embeds
# font subsets into vector outputs (e.g. PDFs); mute them.
logging.getLogger("fontTools").setLevel(logging.WARNING)
logging.getLogger("fontTools.subset").setLevel(logging.WARNING)
logging.getLogger("fontTools.ttLib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
# exca and neuralset attach their own StreamHandlers; disable propagation
# to the root logger to avoid duplicate log lines.
logging.getLogger("exca").propagate = False
logging.getLogger("neuralset").propagate = False
if plot_cached and (force or retry or prepare):
raise ValueError(
"Cannot use force, retry, or prepare flags when plotting cached results."
)
from neuralbench.config_manager import _ensure_initialized
_ensure_initialized()
_validate_inputs(device, task, model, downstream_wrapper)
_warn_slurm_partition(debug, prepare=prepare, download=download)
# --- base config & grid ---
default_config = load_yaml_config(DEFAULTS_DIR / "config.yaml")
config = ConfDict(default_config)
default_grid = load_yaml_config(DEFAULTS_DIR / "grid.yaml")
grid_conf = ConfDict(default_grid)
if prepare or debug:
grid_conf["seed"] = [grid_conf["seed"][0]]
if checkpoint is not None:
config["pretrained_weights_fname"] = checkpoint
# --- downstream wrappers ---
if downstream_wrapper is not None:
wrappers = (
[downstream_wrapper]
if isinstance(downstream_wrapper, str)
else list(downstream_wrapper)
)
if wrappers == ["all"]:
wrappers = list(ALL_DOWNSTREAM_WRAPPERS.keys())
wrapper_configs = [ALL_DOWNSTREAM_WRAPPERS[name] for name in wrappers]
grid_conf["downstream_model_wrapper"] = wrapper_configs
# --- tasks ---
tasks = _resolve_tasks(device, task)
# --- assemble experiment configs ---
configs: list[tp.Any] = []
task_iter: tp.Iterable[str] = tasks
if prepare and len(tasks) > 1:
from tqdm import tqdm
task_iter = tqdm(tasks, desc="Preparing tasks")
for task_name in task_iter:
# Resolve models per-task so the `all` / `all_baseline` aliases pick
# only the task-appropriate sklearn baseline (via FEATURE_BASED_BY_TASK)
# instead of launching every pipeline on every task.
models = _expand_models(model, device=device, task_name=task_name)
datasets = _resolve_datasets(device, task_name, dataset)
task_configs = prepare_task_configs(
config.copy(),
grid_conf,
device,
task_name,
grid,
debug,
force,
prepare,
download,
models,
datasets,
quiet=plot_cached,
retry=retry,
)
configs.extend(task_configs)
if download:
return []
if plot_cached:
        import os
        # Plotting from cache needs no GPU; hide CUDA devices.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
from neuralbench.main import BenchmarkAggregator
agg = BenchmarkAggregator(
experiments=configs,
debug=debug,
)
    results: list[dict[str, tp.Any]] = []
    if plot_cached:
        logger.info("--- PREPARING GLOBAL PLOTS AND TABLES ---")
        results = agg.run(cached_only=True)
    else:
        agg.prepare()
    return results
def run_benchmark_cli() -> None:
"""CLI entry point for ``neuralbench``.
Parses command-line arguments and delegates to :func:`run_benchmark`.
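
    Examples
    --------
    Typical invocations (``<task>`` stands for any registered task name)::

        neuralbench eeg <task> --debug --model all_baseline
        neuralbench eeg <task> --dataset steyrl2016 --grid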
"""
parser = argparse.ArgumentParser(
description="Run neuralbench.",
epilog=_format_datasets_epilog(),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"device",
type=str,
choices=ALL_DEVICES,
help="Brain device on which the desired task relies.",
)
parser.add_argument(
"task",
type=str,
nargs="+",
choices=["all", "all_multi_dataset"] + ALL_TASKS + ALL_UNVALIDATED_TASKS,
help=(
"Task(s) to run. Use 'all' to run all tasks, "
"'all_multi_dataset' to run only tasks with multiple dataset variants, "
"or specify one or more task names."
),
)
parser.add_argument(
"-g", "--grid", action="store_true", help="Run task-specific grid."
)
parser.add_argument(
"-d",
"--debug",
action="store_true",
help="Run in debug mode (locally and smaller config, with infra.mode='force').",
)
mode_group = parser.add_mutually_exclusive_group()
mode_group.add_argument(
"-f", "--force", action="store_true", help="Force rerunning of experiment."
)
mode_group.add_argument(
"-r",
"--retry",
action="store_true",
help="Retry failed experiments (keep completed results).",
)
parser.add_argument(
"-p",
"--prepare",
action="store_true",
help="Run single experiment to prepare cache.",
)
parser.add_argument(
"-m",
"--model",
nargs="*",
choices=["all", "all_classic", "all_fm", "all_baseline"] + ALL_MODELS,
help="Override config to use one or more predefined models. Multiple models will be run in the grid.",
)
parser.add_argument(
"-c",
"--checkpoint",
type=str,
help=(
"Path to a model checkpoint to reload. If this follows the format "
"`wandb:entity/project/grid`, it will be used to find all available checkpoint paths "
"for a specific wandb grid."
),
)
parser.add_argument(
"-w",
"--downstream-wrapper",
nargs="*",
choices=["all"] + list(ALL_DOWNSTREAM_WRAPPERS.keys()),
help="Override/add a model wrapper for the downstream tasks.",
)
parser.add_argument(
"--download",
action="store_true",
help="Download the study. The experiment(s) will not be run.",
)
parser.add_argument(
"--pdb",
action="store_true",
help="Launch pdb on exception.",
)
parser.add_argument(
"--plot-cached",
action="store_true",
help="Plot from cached results only, without running any experiments.",
)
parser.add_argument(
"--dataset",
type=str,
default=None,
help=(
"Specify a dataset variant for the task. "
"Use 'all' to run on all available datasets. "
"If provided, will load dataset-specific overrides from datasets/{dataset}.yaml "
"and merge them with the base config.yaml. "
"Example: --dataset steyrl2016 or --dataset all"
),
)
args = parser.parse_args()
try:
run_benchmark(
device=args.device,
task=args.task,
model=args.model,
dataset=args.dataset,
checkpoint=args.checkpoint,
downstream_wrapper=args.downstream_wrapper,
grid=args.grid,
debug=args.debug,
force=args.force,
retry=args.retry,
prepare=args.prepare,
download=args.download,
plot_cached=args.plot_cached,
)
except Exception:
if not args.pdb:
raise
import pdb
tb = sys.exc_info()[2]
traceback.print_exc()
pdb.post_mortem(tb)
if __name__ == "__main__":
run_benchmark_cli()