neuraltrain.models.common.BahdanauAttention

class neuraltrain.models.common.BahdanauAttention(input_size, hidden_size)[source]

Bahdanau (additive) attention from [1].

Implementation inspired by PyTorch's seq2seq tutorial: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#the-decoder

forward(keys, queries=None)[source]
Parameters:
  • keys – Key tensor of shape (batch_size, n_features, n_times).

  • queries – Optional query tensor of shape (batch_size, n_features, n_times). If None, only keys are used.
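To illustrate the mechanism this class implements, here is a minimal NumPy sketch of additive (Bahdanau) attention over tensors with the (batch_size, n_features, n_times) layout described above. The weight matrices W_k, W_q and vector v are hypothetical learned parameters, and the query handling is simplified to one query per batch element; this is a sketch of the general technique, not the library's exact implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_features, n_times, hidden_size = 2, 4, 5, 8

# Keys shaped (batch_size, n_features, n_times), matching the forward() docs above.
keys = rng.standard_normal((batch_size, n_features, n_times))
# Simplification for this sketch: a single query vector per batch element.
query = rng.standard_normal((batch_size, n_features))

# Hypothetical learned parameters of additive attention.
W_k = rng.standard_normal((hidden_size, n_features))
W_q = rng.standard_normal((hidden_size, n_features))
v = rng.standard_normal(hidden_size)

# Additive score per time step t: score_t = v^T tanh(W_k k_t + W_q q).
proj_keys = np.einsum("hf,bft->bht", W_k, keys)          # (batch, hidden, time)
proj_query = W_q @ query[..., None]                      # (batch, hidden, 1)
scores = np.einsum("h,bht->bt", v, np.tanh(proj_keys + proj_query))

# Softmax over the time axis turns scores into attention weights.
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)

# Context vector: attention-weighted sum of the keys over time.
context = np.einsum("bt,bft->bf", weights, keys)

print(weights.shape)  # (2, 5): one weight per time step, per batch element
print(context.shape)  # (2, 4): one pooled feature vector per batch element
```

When queries is None, an implementation like this one can instead derive the scores from the keys alone (e.g. dropping the W_q q term), which reduces to a learned pooling over the time axis.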