# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utility functions."""
import logging
import typing as tp
from copy import copy
from hashlib import sha1
from pathlib import Path
import lightning.pytorch as pl
import numpy as np
import torch
from sklearn.utils import compute_class_weight
from torch import nn
import neuralset as ns
from neuralset.dataloader import SegmentDataset
LOGGER = logging.getLogger(__name__)
def model_hash(model: nn.Module) -> str:
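    """Return the SHA-1 hex digest of all parameters (used to verify checkpoint loading)."""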
hasher = sha1()
for p in model.parameters():
hasher.update(p.data.cpu().numpy().tobytes())
return hasher.hexdigest()
_PACKAGE_DIR = Path(__file__).resolve().parent
def load_checkpoint(
brain_model: nn.Module,
checkpoint_path: str | Path,
logger: logging.Logger,
) -> nn.Module:
"""Load checkpoint through state_dicts.
Note
----
While pytorch-lightning exposes ways to do this, we implement checkpoint loading
directly for more fine-grained control.
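
    Example
    -------
    A minimal sketch; ``MyEncoder`` and the checkpoint path are hypothetical::

        model = load_checkpoint(MyEncoder(), "checkpoints/best.ckpt", logger=LOGGER)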
"""
checkpoint_path = Path(checkpoint_path).expanduser()
suffix = checkpoint_path.suffix
if checkpoint_path.is_absolute():
checkpoint_path = checkpoint_path.resolve()
else:
checkpoint_path = (_PACKAGE_DIR / checkpoint_path).resolve()
assert suffix in (
".ckpt",
".pth",
".pt",
".safetensors",
), f"Expected .ckpt, .pth, .pt or .safetensors extension but got {checkpoint_path}"
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint path {checkpoint_path} not found.")
logger.info(f"Reloading checkpoint from {checkpoint_path}")
logger.info(f"Initial model hash: {model_hash(brain_model)}")
if suffix == ".safetensors":
from safetensors.torch import load_file
checkpoint = load_file(checkpoint_path, device="cpu")
        # For Braindecode models with explicit channel mapping (e.g. LUNA)
if (mapping := getattr(brain_model, "mapping", None)) is not None:
checkpoint = {mapping.get(k, k): v for k, v in checkpoint.items()}
else:
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
# Load checkpoint and update state dict
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] # type: ignore[assignment]
stripped_state_dict = {}
for name, v in checkpoint.items():
        # PyTorch Lightning prefixes each key with "model."; strip only that
        # leading prefix (str.replace would also mangle inner occurrences).
        stripped_state_dict[name.removeprefix("model.")] = v
model_dict = brain_model.state_dict()
# When the model is a wrapper (e.g. _LunaEncoderWrapper stores the inner
# model as self.model), state dict keys are prefixed with "model." while
# checkpoint keys are not. Try to auto-prefix to match.
if not (set(stripped_state_dict) & set(model_dict)):
for prefix in ["model."]:
prefixed = {f"{prefix}{k}": v for k, v in stripped_state_dict.items()}
if set(prefixed) & set(model_dict):
logger.info(
"Auto-prefixed checkpoint keys with %r to match model state dict.",
prefix,
)
stripped_state_dict = prefixed
break
missing = set(model_dict) - set(stripped_state_dict)
additional = set(stripped_state_dict) - set(model_dict)
logger.info(f"Missing keys in checkpoint: {sorted(missing)}")
logger.info(f"Additional keys in checkpoint: {sorted(additional)}")
stripped_state_dict = {
k: v for k, v in stripped_state_dict.items() if k in model_dict
}
keys_to_remove = []
for k, v in stripped_state_dict.items():
if model_dict[k].size() != v.size():
logger.info(
f"Size mismatch for {k}, checkpoint has shape {v.size()} and current model has shape {model_dict[k].size()}."
)
keys_to_remove.append(k)
for k in keys_to_remove:
stripped_state_dict.pop(k, None)
model_dict.update(stripped_state_dict)
brain_model.load_state_dict(model_dict)
logger.info(f"Loaded model hash: {model_hash(brain_model)}")
return brain_model
def get_targets_from_dataset(dataset: SegmentDataset) -> torch.Tensor:
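    """Materialise only the ``target`` tensor for the whole dataset."""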
feat_dataset = copy(dataset)
# Drop neuro as it takes the most time to process
feat_dataset.extractors = {"target": feat_dataset.extractors["target"]}
return feat_dataset.load_all().data["target"]
def get_neuro_and_targets_from_dataset(
dataset: SegmentDataset,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Materialise ``(neuro, target)`` tensors for the whole dataset.
Used by fit-once baselines (e.g. :class:`SklearnBaseline`) that need
the full training set as a single NumPy array rather than mini-batches.
Other extractors (e.g. ``channel_positions``, ``subject_id``) are
dropped to minimise memory and load time.
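
    Example
    -------
    A minimal sketch; ``train_ds`` and ``clf`` (any scikit-learn estimator)
    are hypothetical::

        neuro, target = get_neuro_and_targets_from_dataset(train_ds)
        clf.fit(neuro.flatten(1).numpy(), target.numpy())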
"""
feat_dataset = copy(dataset)
feat_dataset.extractors = {
"neuro": feat_dataset.extractors["neuro"],
"target": feat_dataset.extractors["target"],
}
data = feat_dataset.load_all().data
return data["neuro"], data["target"]
def compute_class_weights_from_dataset(
train_dataset: SegmentDataset,
logger: logging.Logger,
task: tp.Literal["multiclass", "multilabel", "auto"] = "auto",
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
"""Compute class weights from training dataset for handling class imbalance."""
targets = get_targets_from_dataset(train_dataset)
if isinstance(train_dataset.extractors["target"], ns.extractors.LabelEncoder):
class_mapping = list(train_dataset.extractors["target"]._label_to_ind.keys())
else:
class_mapping = [str(i) for i in range(int(targets.shape[-1]))]
logger.info("Computing class weights...")
if task == "auto":
task = "multiclass"
if targets.ndim == 2:
n_classes_per_example = targets.sum(dim=1)
if (n_classes_per_example > 1).any() or (n_classes_per_example == 0).any():
task = "multilabel"
loss_kwargs = {}
if task == "multilabel":
if targets.ndim != 2:
raise ValueError("Expected 2D targets for multilabel")
n_classes_per_example = targets.sum(dim=1)
has_multi = (n_classes_per_example > 1).any()
has_zero = (n_classes_per_example == 0).any()
if not (has_multi or has_zero):
raise ValueError("Expected some examples with multiple or zero classes")
y_true = targets.clamp(max=1.0).bool().squeeze(dim=1)
pos_weight = (~y_true).sum(dim=0) / y_true.sum(dim=0) # n_negatives / n_positives
pos_weight = torch.nan_to_num(pos_weight, posinf=1.0)
pos_weight_dict = dict(zip(class_mapping, pos_weight.tolist()))
logger.info(f"Positive class weights: {pos_weight_dict}")
loss_kwargs["pos_weight"] = pos_weight # For BCEWithLogitsLoss
elif task == "multiclass":
if targets.ndim == 2:
n_classes = targets.shape[1]
if n_classes == 1:
y_true = targets.squeeze(dim=1)
else:
y_true = targets.argmax(dim=-1)
else:
y_true = targets
n_classes = int(y_true.max().item()) + 1
observed_classes = np.unique(y_true)
observed_weights = torch.tensor(
compute_class_weight(
class_weight="balanced",
classes=observed_classes,
y=y_true.tolist(),
)
).float()
if len(observed_classes) < n_classes:
class_weights = torch.ones(n_classes, dtype=torch.float32)
for cls, w in zip(observed_classes, observed_weights):
class_weights[int(cls)] = w
else:
class_weights = observed_weights
class_weights_dict = dict(zip(class_mapping, class_weights.tolist()))
logger.info(f"Class weights: {class_weights_dict}")
loss_kwargs["weight"] = class_weights
return loss_kwargs, y_true
def make_weighted_sampler(
dataset: SegmentDataset,
logger: logging.Logger,
) -> torch.utils.data.WeightedRandomSampler:
"""Create a weighted random sampler for the given dataset to handle class imbalance."""
loss_kwargs, y_true = compute_class_weights_from_dataset(
dataset,
logger=logger,
task="multiclass",
)
# TODO: Adapt to work with multilabel case as well
weights = loss_kwargs["weight"][y_true]
sampler = torch.utils.data.WeightedRandomSampler(
weights=weights.tolist(),
num_samples=len(weights),
replacement=True,
)
return sampler
class TrainerConfig(ns.BaseModel):
"""Joint configuration for Trainer and some callbacks."""
n_epochs: int = 100
enable_progress_bar: bool = True
log_every_n_steps: int = 20
fast_dev_run: bool = False
gradient_clip_val: float = 0.0
limit_train_batches: int | None = None
limit_val_batches: int | None = None
num_sanity_val_steps: int = 2
accumulate_grad_batches: int = 1
# Hardware
strategy: str = "auto"
precision: str = "32-true"
accelerator: str = "auto"
devices: int = 1
num_nodes: int = 1
# Callbacks
patience: int = 5
monitor: str = "val/loss"
mode: str = "min"
def build(
self,
logger,
callbacks,
accelerator: str | None = None,
devices: int | None = None,
num_nodes: int | None = None,
) -> pl.Trainer:
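        """Build a ``pl.Trainer`` from this config; explicit ``accelerator``,
        ``devices`` or ``num_nodes`` arguments override the stored values."""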
return pl.Trainer(
strategy=self.strategy,
precision=self.precision, # type: ignore[arg-type]
accelerator=self.accelerator if accelerator is None else accelerator,
devices=self.devices if devices is None else devices,
num_nodes=self.num_nodes if num_nodes is None else num_nodes,
gradient_clip_val=self.gradient_clip_val,
limit_train_batches=self.limit_train_batches,
limit_val_batches=self.limit_val_batches,
max_epochs=self.n_epochs,
enable_progress_bar=self.enable_progress_bar,
log_every_n_steps=self.log_every_n_steps,
num_sanity_val_steps=self.num_sanity_val_steps,
fast_dev_run=self.fast_dev_run,
accumulate_grad_batches=self.accumulate_grad_batches,
logger=logger,
callbacks=callbacks,
enable_model_summary=False,
)