# Source code for ``xformers.components.feedforward.mlp``
# (extracted from the rendered xformers API documentation)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass

import torch
import torch.nn as nn

from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig

from . import register_feedforward


@dataclass
class MlpConfig(FeedforwardConfig):
    """Construction parameters specific to the :class:`MLP` feedforward block.

    Extends the shared ``FeedforwardConfig`` (which carries ``dim_model``,
    ``dropout`` and ``activation``) with the MLP-only knobs below.
    """

    # Hidden width of the MLP is hidden_layer_multiplier * dim_model.
    hidden_layer_multiplier: int
    # Whether both Linear layers carry a bias term.
    bias: bool


@register_feedforward("MLP", MlpConfig)
class MLP(Feedforward):
    """A classic two-layer perceptron feedforward block.

    Expands the model dimension by ``hidden_layer_multiplier``, applies the
    configured activation, then projects back down to ``dim_model``.
    Dropout is applied after the activation and after the output projection,
    so the block is shape-preserving on its last dimension.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float,
        activation: Activation,
        hidden_layer_multiplier: int,
        bias: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__()

        # Hidden width is a fixed multiple of the model dimension.
        dim_hidden = hidden_layer_multiplier * dim_model

        layers = [
            nn.Linear(in_features=dim_model, out_features=dim_hidden, bias=bias),
            build_activation(activation),
            nn.Dropout(dropout),
            nn.Linear(in_features=dim_hidden, out_features=dim_model, bias=bias),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run ``inputs`` through the MLP; the output has the same shape."""
        return self.mlp(inputs)