
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import logging
import platform
import resource
import time
import typing as tp
from pathlib import Path

import lightning.pytorch as pl
import torch
import yaml
from exca import TaskInfra
from lightning.pytorch.callbacks import (
    Callback,
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from lightning.pytorch.loggers.logger import DummyLogger, Logger
from pydantic import model_validator
from torch.utils.data import DataLoader
from tqdm import tqdm

import neuralset as ns
from neuraltrain.losses import BaseLoss
from neuraltrain.metrics import BaseMetric
from neuraltrain.models.base import BaseModelConfig
from neuraltrain.optimizers import LightningOptimizer
from neuraltrain.utils import (
    BaseExperiment,
    CsvLoggerConfig,
    StandardScaler,
    WandbLoggerConfig,
)

from .aggregator import (  # noqa: F401
    BenchmarkAggregator as BenchmarkAggregator,
)
from .callbacks import (
    PlotConfusionMatrix,
    PlotRegressionScatter,
    PlotRegressionVectors,
    RecordingLevelEval,
    TestFullRetrievalMetrics,
)
from .data import Data as Data  # noqa: F401
from .model_factory import build_brain_model
from .modules import DownstreamWrapper
from .pl_module import BrainModule
from .utils import TrainerConfig, compute_class_weights_from_dataset

LOGGER = logging.getLogger(__name__)


class Experiment(BaseExperiment):
    """Brain-modeling experiment with support for loading pretrained weights."""

    task_name: str = ""
    # Data
    data: Data
    target_scaler: StandardScaler | None = None
    compute_class_weights: bool = False
    # Model
    brain_model_config: BaseModelConfig
    brain_model_output_size: int | None = (
        None  # For convenience, e.g. when making yaml configs
    )
    pretrained_weights_fname: str | None = None
    downstream_model_wrapper: DownstreamWrapper | None = None
    # Optim
    trainer_config: TrainerConfig
    loss: BaseLoss
    lightning_optimizer_config: LightningOptimizer
    # Evaluation
    eval_only: bool = False
    metrics: list[BaseMetric]
    validate_before_training: bool = True
    test_full_metrics: list[BaseMetric] = []
    test_full_retrieval_metrics: list[BaseMetric] = []
    # Weights & Biases
    csv_config: CsvLoggerConfig | None = None
    wandb_config: WandbLoggerConfig | None = None
    # Internal properties
    _brain_module: pl.LightningModule | None = None
    _trainer: pl.Trainer | None = None
    _csv_logger: CSVLogger | None = None
    _wandb_logger: WandbLogger | None = None
    _n_total_params: int | None = None
    _n_trainable_params: int | None = None
    # Others
    seed: int = 0
    delete_checkpoints_on_exit: bool = True
    infra: TaskInfra = TaskInfra(version="1")
    dummy: dict[str, tp.Any] = {}  # Useful to avoid overwriting experiments between grids
    brain_model_name: str = ""

    @model_validator(mode="after")
    def _populate_brain_model_name(self) -> "Experiment":
        """Auto-populate brain_model_name from the brain_model_config class when unset."""
        if not self.brain_model_name:
            self.brain_model_name = type(self.brain_model_config).__name__
        return self

    @model_validator(mode="after")
    def _validate_metrics_num_classes(self) -> "Experiment":
        """Guard against stale num_classes in metrics after cross-file YAML merging.

        YAML anchors are resolved at parse time within a single file. When a
        dataset override changes ``brain_model_output_size`` without redefining
        ``metrics``, the inherited metric dicts still contain the old
        ``num_classes`` value. This validator catches such mismatches early.
        """
        if self.brain_model_output_size is None:
            return self
        for metric in [*self.metrics, *self.test_full_metrics]:
            kwargs = getattr(metric, "kwargs", None)
            if kwargs is None:
                continue
            for key in ("num_classes", "num_labels"):
                value = kwargs.get(key)
                if value is not None and value != self.brain_model_output_size:
                    raise ValueError(
                        f"Metric '{metric.log_name}' has {key}={value} but "
                        f"brain_model_output_size={self.brain_model_output_size}. "
                        f"Dataset override likely changed brain_model_output_size "
                        f"without redefining metrics."
                    )
        return self
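
    # Illustrative sketch of the failure mode guarded against above. File and
    # key values are hypothetical, not from this repo; only ``log_name`` and
    # ``kwargs`` come from the validator code:
    #
    #   # base.yaml
    #   brain_model_output_size: 10
    #   metrics:
    #     - log_name: accuracy
    #       kwargs: {num_classes: 10}
    #
    #   # dataset_override.yaml (merged on top of base.yaml)
    #   brain_model_output_size: 25   # metrics above still say num_classes=10
    #
    # Without the validator, a 25-way model would silently be scored by a
    # 10-class metric; with it, construction fails fast with a ValueError.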

    def prepare_pl_module(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader | None = None,
    ) -> None:
        """Build the brain model and wrap it in a ``BrainModule``.

        Optionally fits the target scaler on the train loader and computes
        class weights for the loss.
        """
        brain_model, self._n_total_params, self._n_trainable_params = build_brain_model(
            brain_model_config=self.brain_model_config,
            downstream_model_wrapper=self.downstream_model_wrapper,
            pretrained_weights_fname=self.pretrained_weights_fname,
            train_loader=train_loader,
            val_loader=val_loader,
            wandb_logger=self._wandb_logger,
        )
        if self.target_scaler is not None:
            assert hasattr(train_loader.dataset, "extractors")
            # Skip the neuro extractor while streaming batches to fit the
            # target scaler; restore it afterwards.
            neuro_extractor = train_loader.dataset.extractors.pop("neuro")
            for batch in tqdm(train_loader, "Fitting target scaler"):
                self.target_scaler.partial_fit(batch.data["target"])
                if self.target_scaler._n_samples_seen > 5e5:  # cap at ~500k samples
                    break
            train_loader.dataset.extractors["neuro"] = neuro_extractor
        loss_kwargs: dict[str, tp.Any] = {}
        if self.compute_class_weights:
            loss_kwargs, _ = compute_class_weights_from_dataset(
                train_loader.dataset,  # type: ignore[arg-type]
                task=(
                    "multiclass"
                    if self.loss.__class__.__name__ == "CrossEntropyLoss"
                    else "multilabel"
                ),
                logger=LOGGER,
            )
        self._brain_module = BrainModule(
            model=brain_model,
            target_scaler=self.target_scaler,
            loss=self.loss.build(**loss_kwargs),
            lightning_optimizer_config=self.lightning_optimizer_config,
            metrics={metric.log_name: metric.build() for metric in self.metrics},
            test_full_metrics={
                metric.log_name: metric.build() for metric in self.test_full_metrics
            },
            test_full_retrieval_metrics={
                metric.log_name: metric.build()
                for metric in self.test_full_retrieval_metrics
            },
        )
        pl.seed_everything(self.seed)

    def fit(
        self,
        trainer: pl.Trainer,
        train_loader: DataLoader,
        valid_loader: DataLoader,
    ) -> None:
        msg = "Prepare the BrainModule first with self.prepare_pl_module()"
        assert self._brain_module is not None, msg
        if self.validate_before_training:
            LOGGER.info("Validating once before starting training...")
            trainer.validate(model=self._brain_module, dataloaders=[valid_loader])
        # Train model
        trainer.fit(
            model=self._brain_module,
            train_dataloaders=train_loader,
            val_dataloaders=valid_loader,
        )

    def setup_wandb_logger(
        self,
        wandb_config: WandbLoggerConfig,
        savedir: str,
    ) -> WandbLogger:
        """Set up the wandb logger and launch its initialization."""
        import wandb

        wandb.login(host=wandb_config.host)
        logger = wandb_config.build(
            save_dir=savedir,
            xp_config=self.model_dump(),
        )
        try:
            logger.experiment.config["_dummy"] = None  # To launch initialization
        except TypeError:
            pass  # Crashes if called in a second process, e.g. with DDP
        return logger

    def setup_run(self) -> None:
        """Set up output paths and the wandb/CSV loggers."""
        savedir = self.infra.uid_folder()
        LOGGER.info(f"UID folder: {savedir}")
        # Save full config as yaml
        if not savedir.exists():
            savedir.mkdir(parents=True, exist_ok=False)
        with open(savedir / "config.yaml", "w") as outfile:
            yaml.dump(self.model_dump(), outfile, indent=4, default_flow_style=False)
        if self.wandb_config is not None:
            self._wandb_logger = self.setup_wandb_logger(self.wandb_config, str(savedir))
        if self.csv_config is not None:
            self._csv_logger = self.csv_config.build(save_dir=savedir)

    def setup_trainer(self, is_test: bool = False) -> pl.Trainer:
        """Create callbacks and set up the Trainer."""
        callbacks: list[Callback] = []
        if "confusion_matrix" in [metric.log_name for metric in self.metrics]:
            labels: list[str] | None = None
            if isinstance(self.data.target, ns.extractors.LabelEncoder):
                # Invert the label -> index mapping to recover ordered class names
                ind_to_label: dict[int, str] = {}
                for lbl, idx in self.data.target._label_to_ind.items():
                    ind_to_label.setdefault(idx, lbl)
                labels = [ind_to_label[i] for i in sorted(ind_to_label)]
            callbacks.append(PlotConfusionMatrix(labels=labels))
        if is_test:
            if self.test_full_metrics:
                callbacks.append(RecordingLevelEval())
            if self.test_full_retrieval_metrics:
                callbacks.append(
                    TestFullRetrievalMetrics(
                        event_type=self.data.target.event_types,  # type: ignore[arg-type]
                        retrieval_set_sizes=(None, 250),
                        logger=LOGGER,
                        eval_val=False,
                    )
                )
            if self.brain_model_output_size is not None:
                loss_name = (
                    self.loss.name
                    if hasattr(self.loss, "name")
                    else self.loss.__class__.__name__
                )
                # Regression vector visualization for multi-dimensional outputs
                if self.brain_model_output_size > 1 and loss_name in [
                    "MSELoss",
                    "ClipLoss",
                ]:
                    callbacks.append(PlotRegressionVectors(num_samples=10))
                # Scatter plot for 1D regression outputs
                elif self.brain_model_output_size == 1 and loss_name == "MSELoss":
                    callbacks.append(PlotRegressionScatter())
        else:
            callbacks.append(LearningRateMonitor(logging_interval="step"))
            callbacks.append(
                EarlyStopping(
                    monitor=self.trainer_config.monitor,
                    patience=self.trainer_config.patience,
                    mode=self.trainer_config.mode,
                    verbose=True,
                )
            )
            callbacks.append(
                ModelCheckpoint(
                    dirpath=self.infra.uid_folder(),
                    filename="best",
                    monitor=self.trainer_config.monitor,
                    save_last=False,
                    mode=self.trainer_config.mode,
                    save_weights_only=True,
                    save_on_train_epoch_end=None,
                    save_top_k=1,
                    enable_version_counter=False,
                )
            )
        loggers: list[Logger] = []
        if self._wandb_logger is not None:
            loggers.append(self._wandb_logger)
        if self._csv_logger is not None:
            loggers.append(self._csv_logger)
        if not loggers:
            loggers.append(DummyLogger())
        return self.trainer_config.build(
            logger=loggers,
            callbacks=callbacks,
            accelerator="cpu" if self.infra.gpus_per_node == 0 else "auto",
            devices=1 if is_test else self.infra.gpus_per_node,
            num_nodes=1,
        )

    def _test(
        self,
        loaders: dict[str, DataLoader],
        best_model_path: str | None,
    ) -> dict[str, float | None]:
        """Run the test phase on rank 0 with a dedicated trainer."""
        tester = self.setup_trainer(is_test=True)
        return dict(
            tester.test(
                self._brain_module,
                dataloaders=loaders["test"],
                ckpt_path=best_model_path,
            )[0]
        )

    def _cleanup(self, trainer: pl.Trainer) -> None:
        """Delete the best checkpoint and finalize W&B."""
        if (
            self.delete_checkpoints_on_exit
            and not self.eval_only
            and hasattr(trainer.checkpoint_callback, "best_model_path")
        ):
            best_model_path = getattr(
                trainer.checkpoint_callback, "best_model_path", None
            )
            assert best_model_path is not None
            Path(best_model_path).unlink(missing_ok=True)
            LOGGER.info("Deleted checkpoint: %s", best_model_path)
        if self._wandb_logger is not None:
            import wandb

            wandb.finish()

    @infra.apply(
        exclude_from_cache_uid=(
            "wandb_config",
            "csv_config",
            "brain_model_name",
        )
    )
    def run(self) -> dict[str, tp.Any]:
        """Execute the full experiment lifecycle: setup, train, test, cleanup.

        Returns a dict of test metrics (e.g. ``{"test/bal_acc": 0.85, ...}``)
        plus ``n_total_params`` and ``n_trainable_params``.
        """
        logging.basicConfig(level=logging.INFO)
        logging.getLogger("numexpr").setLevel(logging.WARNING)
        logging.getLogger("fontTools").setLevel(logging.WARNING)
        logging.getLogger("fontTools.subset").setLevel(logging.WARNING)
        logging.getLogger("fontTools.ttLib").setLevel(logging.WARNING)
        logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
        logging.getLogger("exca").propagate = False
        logging.getLogger("neuralset").propagate = False

        self.setup_run()
        loaders = self.data.prepare()
        trainer = self.setup_trainer()
        self.prepare_pl_module(loaders["train"], loaders.get("val"))

        test_results: dict[str, tp.Any] = {}
        training_time_s: float | None = None
        peak_gpu_memory_mb: float | None = None
        peak_cpu_memory_mb: float | None = None
        # Initialize upfront so the test phase never sees an unbound name,
        # even when no ModelCheckpoint callback is attached.
        best_model_path: str | None = None
        if not self.eval_only:
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            t0 = time.perf_counter()
            self.fit(trainer, loaders["train"], loaders["val"])
            training_time_s = time.perf_counter() - t0
            if torch.cuda.is_available():
                peak_gpu_memory_mb = torch.cuda.max_memory_allocated() / (1024**2)
            rusage = resource.getrusage(resource.RUSAGE_SELF)
            # ru_maxrss: macOS reports bytes, Linux reports KB
            divisor = 1024 if platform.system() != "Darwin" else 1024**2
            peak_cpu_memory_mb = rusage.ru_maxrss / divisor
            if isinstance(trainer.checkpoint_callback, ModelCheckpoint):
                best_model_path = trainer.checkpoint_callback.best_model_path
                assert best_model_path is not None
                best_ckpt = torch.load(
                    best_model_path,
                    map_location=torch.device("cpu"),
                    weights_only=True,
                )
                best_epoch = best_ckpt["epoch"]
                LOGGER.info("Best epoch: %i", best_epoch)
                for logger in [self._wandb_logger, self._csv_logger]:
                    if logger is not None:
                        logger.log_metrics({"best_epoch": best_epoch})

        if (
            getattr(self.infra, "gpus_per_node", 0) > 1
            and torch.distributed.is_initialized()
        ):
            torch.distributed.destroy_process_group()
        if trainer.global_rank == 0:
            test_results.update(self._test(loaders, best_model_path))
        self._cleanup(trainer)

        test_results.update(
            {
                "n_total_params": self._n_total_params,
                "n_trainable_params": self._n_trainable_params,
                "training_time_s": training_time_s,
                "peak_gpu_memory_mb": peak_gpu_memory_mb,
                "peak_cpu_memory_mb": peak_cpu_memory_mb,
            }
        )
        return test_results

# BenchmarkAggregator lives in aggregator.py and forward-references Experiment
# (under TYPE_CHECKING to avoid a circular import). Pydantic needs the real
# class at validation time, so we rebuild the schema here once both are defined.
BenchmarkAggregator.model_rebuild(_types_namespace={"Experiment": Experiment})
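
# Minimal usage sketch (the config path and keys are hypothetical; the actual
# entry point for launching experiments may live elsewhere in the repo):
#
#     import yaml
#
#     with open("configs/example.yaml") as f:
#         xp = Experiment(**yaml.safe_load(f))
#     results = xp.run()
#     # e.g. {"test/bal_acc": ..., "n_total_params": ..., "training_time_s": ...}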