Shortcuts

Source code for xformers.components.feedforward.base

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar

import torch.nn as nn

from xformers.components import Activation

# Type variable bound to Feedforward: lets `Feedforward.from_config` be
# annotated as returning an instance of whichever class it is called on.
Self = TypeVar("Self", bound="Feedforward")


@dataclass
class FeedforwardConfig:
    """Configuration fields shared by feedforward blocks.

    ``Feedforward.from_config`` converts an instance to a dict with
    ``dataclasses.asdict`` and passes every field whose value is not None
    to the feedforward constructor as a keyword argument of the same name.
    """

    # NOTE(review): presumably identifies which feedforward variant to build;
    # the base `Feedforward.__init__` only receives it through **kwargs.
    name: str
    # Forwarded as the `dim_model` constructor argument.
    dim_model: int
    # Forwarded as the `dropout` constructor argument.
    dropout: float
    # Forwarded as the `activation` constructor argument.
    activation: Activation

# Define the common interface; every feedforward block needs to derive from it.
class Feedforward(nn.Module, metaclass=ABCMeta):
    """Abstract base class defining the common feedforward interface.

    ``__init__`` is abstract, so this class cannot be instantiated directly;
    concrete feedforward blocks must override it and call
    ``super().__init__()`` to get the default capability flags set below.
    """

    @abstractmethod
    def __init__(
        self,
        dim_model: Optional[int] = None,
        dropout: Optional[float] = None,
        activation: Optional[Activation] = None,
        *args,
        **kwargs,
    ):
        """Initialise the module and set the default capability flags.

        The base implementation does not store ``dim_model``, ``dropout`` or
        ``activation``, and it silently accepts any extra positional or
        keyword arguments; using them is left to the subclasses.

        Args:
            dim_model: accepted for subclasses, unused here.
            dropout: accepted for subclasses, unused here.
            activation: accepted for subclasses, unused here.
        """
        super().__init__()

        # This feedforward requires a CUDA accelerator
        self.requires_cuda = False

        # This feedforward requires a context length which is squared, often due to 2D pooling
        self.requires_squared_context = False

    @classmethod
    def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self:
        """Build an instance of ``cls`` from a config dataclass.

        Args:
            config: dataclass instance whose field names match the keyword
                arguments accepted by ``cls``.

        Returns:
            A new instance of ``cls`` (the class this method is called on).

        Note:
            ``dataclasses.asdict`` recurses into nested dataclasses (turning
            them into plain dicts) and deep-copies the other field values.
        """
        # Generate the class inputs from the config
        kwargs = asdict(config)

        # Skip all Nones so that default values are used
        kwargs = {key: value for key, value in kwargs.items() if value is not None}

        return cls(**kwargs)