# Source code for neuraltrain.augmentations.augmentations

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import numpy as np
import pydantic
import torch
from torch import nn


class ChannelsDropoutConfig(pydantic.BaseModel):
    """Configuration for braindecode's ``ChannelsDropout`` augmentation.

    Parameters
    ----------
    probability : float
        Probability of applying the augmentation to a given example.
    p_drop : float
        Fraction of channels to drop.
    """

    probability: float
    p_drop: float

    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid")

    def build(self) -> nn.Module:
        """Instantiate the configured ``ChannelsDropout`` module."""
        # Lazy import keeps braindecode optional until a module is built.
        from braindecode.augmentation import ChannelsDropout

        kwargs = {"probability": self.probability, "p_drop": self.p_drop}
        return ChannelsDropout(**kwargs)
class FrequencyShiftConfig(pydantic.BaseModel):
    """Configuration for braindecode's ``FrequencyShift`` augmentation.

    Parameters
    ----------
    probability : float
        Probability of applying the augmentation to a given example.
    sfreq : float
        Sampling frequency of the recording, in Hz.
    max_delta_freq : float
        Maximum frequency shift (in Hz) applied to the signal.
    """

    probability: float
    sfreq: float
    max_delta_freq: float

    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid")

    def build(self) -> nn.Module:
        """Instantiate the configured ``FrequencyShift`` module."""
        # Lazy import keeps braindecode optional until a module is built.
        from braindecode.augmentation import FrequencyShift

        kwargs = {
            "probability": self.probability,
            "sfreq": self.sfreq,
            "max_delta_freq": self.max_delta_freq,
        }
        return FrequencyShift(**kwargs)
class GaussianNoiseConfig(pydantic.BaseModel):
    """Configuration for braindecode's ``GaussianNoise`` augmentation.

    Parameters
    ----------
    probability : float
        Probability of applying the augmentation to a given example.
    std : float
        Standard deviation of the additive Gaussian noise.
    """

    probability: float
    std: float

    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid")

    def build(self) -> nn.Module:
        """Instantiate the configured ``GaussianNoise`` module."""
        # Lazy import keeps braindecode optional until a module is built.
        from braindecode.augmentation import GaussianNoise

        kwargs = {"probability": self.probability, "std": self.std}
        return GaussianNoise(**kwargs)
class SmoothTimeMaskConfig(pydantic.BaseModel):
    """Configuration for braindecode's ``SmoothTimeMask`` augmentation.

    Parameters
    ----------
    probability : float
        Probability of applying the augmentation to a given example.
    mask_len_samples : int
        Length (in samples) of the time block to mask.
    """

    probability: float
    mask_len_samples: int

    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid")

    def build(self) -> nn.Module:
        """Instantiate the configured ``SmoothTimeMask`` module."""
        # Lazy import keeps braindecode optional until a module is built.
        from braindecode.augmentation import SmoothTimeMask

        kwargs = {
            "probability": self.probability,
            "mask_len_samples": self.mask_len_samples,
        }
        return SmoothTimeMask(**kwargs)
class BandstopFilterFFTConfig(pydantic.BaseModel):
    """Configuration for :class:`BandstopFilterFFT`.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the recording, in Hz.
    bandwidth : float
        Bandwidth (in Hz) of the bandstop filter. Must be less than half
        the sampling frequency.
    """

    sfreq: float
    bandwidth: float

    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid")

    def build(self) -> nn.Module:
        """Instantiate the configured :class:`BandstopFilterFFT` module."""
        kwargs = {"sfreq": self.sfreq, "bandwidth": self.bandwidth}
        return BandstopFilterFFT(**kwargs)
class BandstopFilterFFT(nn.Module):
    """
    Bandstop data augmentation, applying a bandstop filter to the data
    using Fourier transform.

    For each example in the batch, a contiguous band of rfft bins (whose
    width corresponds to ``bandwidth`` Hz) is chosen at a random position
    and zeroed across all channels before inverse-transforming.

    Parameters
    ----------
    sfreq:
        Sampling frequency of the recording
    bandwidth:
        Bandwidth of the bandstop filter

    Raises
    ------
    ValueError
        If ``bandwidth`` is not smaller than half of ``sfreq``.
    """

    def __init__(
        self,
        sfreq: float,
        bandwidth: float,
    ):
        super().__init__()
        if bandwidth * 2 > sfreq:
            raise ValueError(
                "Bandwidth needs to be smaller than half of sampling frequency."
            )
        self.sfreq = sfreq
        self.bandwidth = bandwidth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Zero a random frequency band of ``x`` and return the filtered signal.

        Assumes ``x`` is shaped ``(batch, channels, time)`` — the per-example
        indexing below relies on this layout.
        """
        ffted = torch.fft.rfft(
            x,
        )
        # Number of rfft bins covered by the stop band.
        n_bins = int(np.round(self.bandwidth * 2 * ffted.shape[-1] / self.sfreq))
        # Independent random band start position for each example.
        i_bins = torch.randint(ffted.shape[-1] - n_bins, (len(ffted),))
        for i_example, i_bin in enumerate(i_bins):
            ffted[i_example, :, i_bin : i_bin + n_bins] = 0
        # Pass n explicitly: irfft's default output length is 2 * (m - 1),
        # which silently shrinks the signal by one sample whenever the input
        # time dimension is odd. Forcing n keeps output shape == input shape.
        iffted = torch.fft.irfft(ffted, n=x.shape[-1])
        return iffted
class TrivialBrainAugmentConfig(pydantic.BaseModel):
    """Configuration for :class:`TrivialBrainAugment`.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of the recording, in Hz.
    min_max_ch_drop : tuple[float, float]
        Min/max of the channel-dropout probability range.
    min_max_gauss_noise : tuple[float, float]
        Min/max of the Gaussian-noise standard-deviation range.
    min_max_time_mask : tuple[float, float]
        Min/max of the time-mask length range (samples, log-spaced).
    min_max_bandstop : tuple[float, float]
        Min/max of the bandstop bandwidth range (Hz, log-spaced).
    min_max_freq_shift : tuple[float, float]
        Min/max of the frequency-shift range (Hz).
    """

    sfreq: float
    min_max_ch_drop: tuple[float, float] = (0.05, 0.4)
    min_max_gauss_noise: tuple[float, float] = (0.01, 0.3)
    min_max_time_mask: tuple[float, float] = (2, 32)
    min_max_bandstop: tuple[float, float] = (1, 8)
    min_max_freq_shift: tuple[float, float] = (-1, 1)

    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="forbid")

    def build(self) -> nn.Module:
        """Instantiate a :class:`TrivialBrainAugment` from this config."""
        return TrivialBrainAugment(self)
class TrivialBrainAugment(nn.Module):
    """
    Inspired by TrivialAugment [1], sample augmentations and strength
    randomly on each minibatch/forward pass.

    Parameters
    ----------
    cfg:
        Configuration that contains values for:
        sfreq: Sampling frequency of the recording
        min_max_ch_drop: Min/Max for linspace of channel dropout probabilities
        min_max_gauss_noise: Min/Max for linspace of gaussian noise standard
        deviation
        min_max_time_mask: Min/Max for logspace of length of timeblock to be
        masked
        min_max_bandstop: Min/Max for logspace of frequency width of bandstop
        filter
        min_max_freq_shift: Min/Max for linspace of frequency shift

    References
    ----------
    .. [1] Mueller, Samuel and Hutter, Frank. "TrivialAugment: Tuning-free
       Yet State-of-the-Art Data Augmentation"
    """

    def __init__(self, cfg: TrivialBrainAugmentConfig):
        super().__init__()
        self.cfg = cfg

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply one randomly chosen augmentation at a random strength."""
        from braindecode.augmentation import (
            ChannelsDropout,
            FrequencyShift,
            GaussianNoise,
            SmoothTimeMask,
        )

        num_strengths = 32
        # Strength index is drawn first, then the transform index — same
        # order of torch RNG draws as the original implementation.
        strength = int(torch.randint(0, num_strengths, (1,)).item())

        def _lin(min_max: tuple[float, float]) -> float:
            # Strength-indexed value on a linear grid between min and max.
            return np.linspace(min_max[0], min_max[1], num=num_strengths)[strength]

        def _log2(min_max: tuple[float, float]) -> float:
            # Strength-indexed value on a base-2 logarithmic grid.
            return np.logspace(
                np.log2(min_max[0]),
                np.log2(min_max[1]),
                base=2,
                num=num_strengths,
            )[strength]

        # Deferred constructors: the original built all five transform
        # modules on every forward pass even though exactly one is applied.
        # Building only the sampled one avoids that wasted work. (Assumes
        # the braindecode transform constructors do not consume torch RNG —
        # verify if exact reproducibility against the old version matters.)
        builders = [
            lambda: ChannelsDropout(1, p_drop=_lin(self.cfg.min_max_ch_drop)),
            lambda: GaussianNoise(1, std=_lin(self.cfg.min_max_gauss_noise)),
            lambda: SmoothTimeMask(
                1,
                mask_len_samples=int(
                    np.logspace(
                        np.log2(self.cfg.min_max_time_mask[0]),
                        np.log2(self.cfg.min_max_time_mask[1]),
                        base=2,
                        num=num_strengths,
                    ).round()[strength]
                ),
            ),
            lambda: BandstopFilterFFT(
                sfreq=self.cfg.sfreq,
                bandwidth=_log2(self.cfg.min_max_bandstop),
            ),
            lambda: FrequencyShift(
                1,
                sfreq=self.cfg.sfreq,
                max_delta_freq=_lin(self.cfg.min_max_freq_shift),
            ),
        ]
        i_transform = int(torch.randint(0, len(builders), (1,)).item())
        transform = builders[i_transform]()
        return transform(x)