# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
from torch import nn
from torch.nn.functional import conv2d
from .base import BaseModelConfig
# Default band edges, in Hz, used to initialize the sinc filterbank. Each pair
# of consecutive edges bounds one band-pass filter, so these 7 edges yield
# 6 filters.
EEG_FREQ_LIMS_HZ = [0.1, 0.5, 4.0, 8.0, 12.0, 30.0, 60.0]
[docs]
class FreqBandNet(BaseModelConfig):
    """Configuration of a parametrized filterbank feature extractor.

    The model built from this configuration applies sinc band-pass filters,
    average-pools the squared output (power) and takes its logarithm.

    Parameters
    ----------
    sfreq : float or None
        Sampling frequency of the input time series, in Hz. May instead be
        passed to :meth:`build`.
    freq_lims_hz : list of float or None
        Band-pass cutoff frequencies used to initialize the sinc filters;
        ``len(freq_lims_hz) - 1`` filters are created. Mutually exclusive
        with *n_filters_conv*.
    n_filters_conv : int or None
        Number of filters to initialize with log-spaced cutoffs. Mutually
        exclusive with *freq_lims_hz*.
    conv_kernel_len : int
        Length, in samples, of each sinc convolution kernel.
    conv_stride : int
        Stride of the sinc convolution.
    conv_padding : {"valid", "same"}
        Padding mode of the sinc convolution.
    pool_kernel_len : int
        Kernel length of the average pooling applied after the filterbank.
    flat_out : {"channels", "channels_and_time"} or None
        How the filterbank output is flattened before the classifier.
        ``"channels"`` merges filters and channels → ``(B, F*C, T)``.
        ``"channels_and_time"`` flattens everything → ``(B, F*C*T)``.
        ``None`` leaves the output unflattened.
    n_outputs : int or None
        When set, a linear output layer with this many units is appended.
    """

    sfreq: float | None = None
    freq_lims_hz: list[float] | None = EEG_FREQ_LIMS_HZ
    n_filters_conv: int | None = None
    conv_kernel_len: int = 65
    conv_stride: int = 1
    conv_padding: tp.Literal["valid", "same"] = "valid"
    pool_kernel_len: int = 30
    flat_out: tp.Literal["channels", "channels_and_time"] | None = None
    n_outputs: int | None = None

    def model_post_init(self, __context):
        """Check that exactly one filter-initialization option is set.

        Raises
        ------
        ValueError
            If *freq_lims_hz* and *n_filters_conv* are both set or both None.
        """
        super().model_post_init(__context)
        has_cutoffs = self.freq_lims_hz is not None
        has_n_filters = self.n_filters_conv is not None
        # Both set, or both missing, is ambiguous: exactly one must be given.
        if has_cutoffs == has_n_filters:
            raise ValueError(
                "Exactly one of freq_lims_hz and n_filters_conv must be specified."
            )

    def build(
        self,
        n_in_channels: int | None = None,  # Unused; for compatibility with other models
        n_outputs: int | None = None,
        sfreq: float | None = None,
    ) -> nn.Module:
        """Instantiate the :class:`FreqBandNetModel` described by this config.

        Parameters
        ----------
        n_in_channels : int or None
            Ignored; accepted so the signature matches other models.
        n_outputs : int or None
            Size of the linear output layer; falls back to ``self.n_outputs``
            when falsy.
        sfreq : float or None
            Sampling frequency in Hz; falls back to ``self.sfreq`` when falsy.

        Raises
        ------
        ValueError
            If no sampling frequency is available from either source.
        """
        resolved_sfreq = sfreq or self.sfreq
        if resolved_sfreq is None:
            raise ValueError("sfreq must be provided to build the model.")

        resolved_n_outputs = n_outputs or self.n_outputs
        return FreqBandNetModel(
            sfreq=resolved_sfreq,
            config=self,
            n_outputs=resolved_n_outputs,
        )
[docs]
class FreqBandNetModel(nn.Module):
    """Simple parametrized filterbank feature extractor.

    Pipeline: sinc band-pass filters, squared output averaged over time
    (power), log nonlinearity, optional flattening and an optional linear
    output layer.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the input, in Hz.
    config : FreqBandNet
        Configuration providing the filterbank, pooling and flattening options.
    n_outputs : int or None
        When not None, a lazily-initialized linear layer with this many
        output units is applied last.
    """

    def __init__(
        self,
        sfreq: float,
        config: FreqBandNet,
        n_outputs: int | None = None,
    ):
        super().__init__()

        # The filterbank sees the input as a single-channel "image" of shape
        # (C, T), hence in_channels=1.
        self.sinc_conv = SincConv(
            in_channels=1,
            out_channels=config.n_filters_conv,
            kernel_size=config.conv_kernel_len,
            stride=config.conv_stride,
            sfreq=sfreq,
            padding=config.conv_padding,
            freq_lims_hz=config.freq_lims_hz,
        )
        # Pool along the last (time) axis only.
        self.pool = nn.AvgPool2d(kernel_size=(1, config.pool_kernel_len))
        self.flat_out = config.flat_out
        self.classifier = nn.LazyLinear(n_outputs) if n_outputs is not None else None

    def forward(
        self,
        x: torch.Tensor,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """Apply the filterbank, compute log-power, and optionally classify.

        Parameters
        ----------
        x : Tensor
            Input time series of shape ``(B, C, T)``.
        eps : float
            Small constant added before ``log`` for numerical stability.

        Returns
        -------
        Tensor
            Log-power features, flattened according to ``flat_out`` and passed
            through the linear output layer when one was requested.
        """
        n_batch, n_chans, n_times = x.shape

        # Insert a singleton axis so the filters become the conv channel axis.
        feats = self.sinc_conv(x.reshape(n_batch, 1, n_chans, n_times))
        power = self.pool(feats**2)  # Average power -> (B, F, C, T')
        feats = torch.log(power + eps)  # Log-power

        if self.flat_out == "channels":
            n_batch, n_filters, n_chans, n_pooled = feats.shape
            feats = feats.reshape(n_batch, n_filters * n_chans, n_pooled)
        elif self.flat_out is not None:
            feats = feats.reshape(feats.shape[0], -1)

        if self.classifier is None:
            return feats
        return self.classifier(feats)
def sinc_from_half(x_left: torch.Tensor) -> torch.Tensor:
    """Return ``sin(x) / x`` over a full symmetric kernel, given its left half.

    Only the left half is evaluated; the right half is obtained by mirroring
    it and the central tap is set to 1.

    Parameters
    ----------
    x_left : Tensor
        2-D tensor of shape ``(n_rows, n_half)`` holding the arguments for the
        left half of the kernel.

    Returns
    -------
    Tensor
        Tensor of shape ``(n_rows, 2 * n_half + 1)`` laid out as
        ``[sinc(x_left), 1, flip(sinc(x_left))]``, with the same dtype and
        device as *x_left*.
    """
    # sin(x) / x is 0 / 0 at x == 0 (reached e.g. with a 0 Hz cutoff). Divide
    # by a safe stand-in there and substitute the limit value 1, so that both
    # the output and its gradient stay finite. Non-zero entries are untouched.
    is_zero = x_left == 0
    unit = torch.ones_like(x_left)
    safe_x = torch.where(is_zero, unit, x_left)
    out_left = torch.where(is_zero, unit, torch.sin(safe_x) / safe_x)

    # Build the central column from x_left so that its dtype and device match;
    # a default-dtype column would make torch.cat promote half-precision input.
    center = x_left.new_ones((out_left.shape[0], 1))

    return torch.cat(
        [out_left, center, torch.flip(out_left, dims=[-1])],
        dim=1,
    )
class SincConv(nn.Module):
"""Parametrized windowed Sinc-based convolution.
See https://arxiv.org/pdf/1808.00158
XXX Use only half the window?
"""
def __init__(
self,
in_channels: int,
out_channels: int | None,
kernel_size: int,
stride: int,
sfreq: float,
padding: tp.Literal["valid", "same"] = "valid",
min_low_hz: float = 0.1,
max_high_hz: float = 50.0,
freq_lims_hz: list[float] | None = None,
) -> None:
super().__init__()
if in_channels != 1:
raise NotImplementedError()
if kernel_size % 2 != 1:
raise ValueError(f"{kernel_size=} must be odd")
if not ((freq_lims_hz is None) ^ (out_channels is None)):
raise ValueError(
f"Exactly one of {freq_lims_hz=} and {out_channels=} must be specified."
)
self.stride = stride
self.sfreq = sfreq
self.padding = padding
# Filter cutoff parameters
if out_channels is not None:
freq_lims = (
torch.logspace(
torch.tensor(min_low_hz).log10(),
torch.tensor(max_high_hz).log10(),
steps=out_channels + 1,
base=10.0,
)
/ sfreq
)
elif freq_lims_hz is not None:
freq_lims = torch.tensor(freq_lims_hz) / sfreq
if freq_lims[-1] > 0.5:
cutoff = freq_lims[-1] * sfreq
raise ValueError(f"Cutoff {cutoff:.2f} Hz exceeds Nyquist {sfreq / 2} Hz")
self.low_freqs = nn.Parameter(freq_lims[:-1], requires_grad=True)
self.bandwidths = nn.Parameter(freq_lims[1:] - freq_lims[:-1], requires_grad=True)
# Define Hamming window
n_lin = torch.linspace(
torch.tensor(0.0),
torch.tensor(kernel_size) - 1,
steps=kernel_size,
)
_window = 0.54 - 0.46 * torch.cos(2 * torch.pi * n_lin / kernel_size)
self.register_buffer("_window", _window)
# Prepare sinc function arguments
n = (kernel_size - 1) / 2.0
_n = 2 * torch.pi * torch.arange(-n, 0).view(1, -1)
self.register_buffer("_n", _n)
def forward(
self, x: torch.Tensor, return_filters: bool = False
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Parameters
----------
x :
Input time series of shape (B, 1, C, T).
return_filters :
If True, also return the windowed sinc filters.
"""
# Compute filters
low_freqs = torch.abs(self.low_freqs)[:, None]
high_freqs = low_freqs + torch.abs(self.bandwidths)[:, None]
g_low = 2 * low_freqs * sinc_from_half(low_freqs @ self._n) # type: ignore
g_high = 2 * high_freqs * sinc_from_half(high_freqs @ self._n) # type: ignore
filters = (g_high - g_low) * self._window # type: ignore
# Apply filters
out = conv2d(
x,
filters[:, None, None, :],
stride=(1, self.stride),
padding=self.padding,
)
if return_filters:
return out, filters
return out