Shortcuts

Source code for xformers.components.positional_embedding.sine

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# Silence Mypy errors in this file.
# type: ignore

import math

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@register_positional_embedding("sine", PositionEmbeddingConfig)
class SinePositionalEmbedding(PositionEmbedding):
    """Add a fixed (non-learned) sinusoidal position signal to the input.

    For position ``p`` and channel ``j`` the angle is
    ``p * 10000 ** (-2 * (j // 2) / dim_model)``; even channels hold its
    sine and odd channels hold its cosine.  The table is rebuilt on every
    call from the sequence length of the incoming tensor, so the module
    owns no parameters or buffers of its own.
    """

    def __init__(self, dim_model: int, *args, **kwargs):
        """Store the encoding width.

        Args:
            dim_model: Number of channels in the generated encoding.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.dim_model = dim_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the sinusoidal encoding added.

        Args:
            x: Tensor whose dimension 1 is the sequence length.  A 2-D
                input gets a trailing singleton dimension appended before
                the addition, so it broadcasts across all ``dim_model``
                channels.  Otherwise ``x`` must broadcast against a
                ``(1, seq_len, dim_model)`` tensor.

        Returns:
            ``x`` (or its 3-D view) plus the encoding.  The encoding is
            built in float32 on ``x.device``; the dtype of the sum follows
            PyTorch's type-promotion rules.
        """
        seq_len = x.shape[1]

        # Column of positions (seq_len, 1) and row of channel indices
        # (dim_model,).  Broadcasting their product yields the full
        # (seq_len, dim_model) table without materialising repeated grids.
        positions = torch.arange(
            seq_len, device=x.device, dtype=torch.float32
        ).unsqueeze(1)
        channels = torch.arange(
            self.dim_model, device=x.device, dtype=torch.float32
        )

        # The frequency term depends only on the channel, so compute it once.
        # `channels // 2` makes each (even, odd) channel pair share a frequency.
        inv_freq = torch.exp(
            -math.log(10000) * (2 * (channels // 2) / self.dim_model)
        )

        # The product is a fresh tensor, so the in-place slice writes below
        # cannot alias `positions` or `inv_freq`.
        encoding = positions * inv_freq
        encoding[:, 0::2] = torch.sin(encoding[:, 0::2])
        encoding[:, 1::2] = torch.cos(encoding[:, 1::2])

        output = x.unsqueeze(-1) if x.ndim == 2 else x

        # Leading singleton dimension broadcasts the table over the batch.
        return output + encoding.unsqueeze(0)