Incremental State

class fairseq2.nn.IncrementalState[source]

Bases: ABC

Holds the state of a module during incremental decoding.

Incremental decoding is an inference-time mode in which the module receives only the input that corresponds to the previous output and must produce the next output. The module must therefore cache any long-term state it needs about the sequence.
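The idea can be illustrated without any library code. The sketch below is a dependency-free toy, not fairseq2's implementation: `RunningMeanState`, `full_forward`, and `incremental_forward` are hypothetical names, and a running mean stands in for a real module. It shows that a module fed one input at a time can match the full-sequence result as long as it caches what it needs about the past.

```python
from dataclasses import dataclass


@dataclass
class RunningMeanState:
    """Hypothetical cached state: what the toy module needs to know about the past."""

    total: float = 0.0
    count: int = 0


def full_forward(seq: list[float]) -> list[float]:
    """Non-incremental mode: the whole sequence is visible at once."""
    return [sum(seq[: i + 1]) / (i + 1) for i in range(len(seq))]


def incremental_forward(x: float, state: RunningMeanState) -> float:
    """Incremental mode: only the newest input is visible, so the cache is updated."""
    state.total += x
    state.count += 1
    return state.total / state.count


seq = [2.0, 4.0, 6.0, 8.0]
state = RunningMeanState()

# One call per decoding step, each seeing a single new input.
incremental_out = [incremental_forward(x, state) for x in seq]
```

Both modes produce the same outputs here; the incremental one simply avoids revisiting earlier inputs at each step.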

abstract reorder(new_order)[source]

Rearranges the state according to a new batch order.

This method is called when the order of the batch has changed. A typical use case is beam search, where the batch order changes between steps based on the selection of beams.

Parameters:

new_order (Tensor) – The new order of the batch. It is frequently used with torch.index_select() to rearrange the state tensors. Shape: \((N)\), where \(N\) is the batch size.
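As a rough illustration of what reordering means for cached state, the sketch below uses plain Python lists in place of tensors so that it runs without dependencies. `reorder_rows` is a hypothetical helper that mimics an index-select along the batch dimension; a real state would typically apply torch.index_select() to its tensors instead.

```python
def reorder_rows(state_rows: list[list[int]], new_order: list[int]) -> list[list[int]]:
    """Make row new_order[i] of the old state the new row i (copying each row)."""
    return [list(state_rows[i]) for i in new_order]


# Toy cached state for a batch of N = 3 beams, one row per beam.
cache = [[10, 11], [20, 21], [30, 31]]

# Suppose beam search kept old beam 2 once and old beam 0 twice; beam 1 was dropped.
new_order = [2, 0, 0]

cache = reorder_rows(cache, new_order)
```

Note that an index may appear more than once in `new_order`, in which case the corresponding state row is duplicated rather than shared.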

abstract size_bytes()[source]

Returns the size of the state in bytes.

Return type:

int

abstract capacity_bytes()[source]

Returns the reserved capacity of the state in bytes.

Return type:

int
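The difference between the two byte counts can be sketched with a state that reserves storage up front. This is an assumption about how the two values relate (used portion versus reserved portion), not a description of any concrete fairseq2 state; `BufferState` and its `append` method are hypothetical.

```python
class BufferState:
    """Hypothetical state that reserves more storage than it currently uses."""

    def __init__(self, capacity: int, item_size: int = 4) -> None:
        self.item_size = item_size  # bytes per cached element
        self.buffer = bytearray(capacity * item_size)  # reserved storage
        self.length = 0  # number of elements actually in use

    def append(self, item: bytes) -> None:
        start = self.length * self.item_size
        end = start + self.item_size
        if len(item) != self.item_size or end > len(self.buffer):
            raise ValueError("item does not fit the reserved buffer")
        self.buffer[start:end] = item
        self.length += 1

    def size_bytes(self) -> int:
        """Bytes occupied by the elements cached so far."""
        return self.length * self.item_size

    def capacity_bytes(self) -> int:
        """Bytes reserved, whether or not they are in use yet."""
        return len(self.buffer)


state = BufferState(capacity=16)
for step in range(3):
    state.append(step.to_bytes(4, "little"))
```

After three steps this toy state reports 12 bytes in use against 64 bytes reserved.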

final class fairseq2.nn.IncrementalStateBag(max_num_steps, *, capacity_increment=16)[source]

Bases: object

Holds the module states during incremental decoding.

Parameters:
  • max_num_steps (int) – The maximum number of steps to take.

  • capacity_increment (int | None) – The sequence length capacity of state tensors will be incremented by multiples of this value. If None, state tensors will be preallocated with a capacity of max_num_steps.
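One plausible reading of capacity_increment is that a state rounds the sequence length it needs up to the next multiple of the increment. The helper below, `reserved_capacity`, is hypothetical and only sketches that arithmetic; whether the library additionally caps the result at max_num_steps is not stated here, so the sketch does not cap it.

```python
from typing import Optional


def reserved_capacity(
    needed_len: int, max_num_steps: int, capacity_increment: Optional[int]
) -> int:
    """Return the sequence length to reserve for a state tensor (illustrative only)."""
    if capacity_increment is None:
        # Preallocate for the whole decoding run.
        return max_num_steps

    # Round up to the next multiple of the increment using integer arithmetic.
    num_blocks = (needed_len + capacity_increment - 1) // capacity_increment
    return num_blocks * capacity_increment


small = reserved_capacity(1, max_num_steps=256, capacity_increment=16)
exact = reserved_capacity(16, max_num_steps=256, capacity_increment=16)
grown = reserved_capacity(17, max_num_steps=256, capacity_increment=16)
fixed = reserved_capacity(5, max_num_steps=256, capacity_increment=None)
```

A larger increment means fewer reallocations at the cost of more unused reserved memory; None trades all reallocations for a single up-front allocation.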

increment_step_nr(value=1)[source]

Increments the step number.

This method should be called after every decoding step. It is used by modules to keep track of the position in the sequence.

Parameters:

value (int) – The value by which to increment the step number.
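To show why the step number matters, the sketch below uses a hypothetical `StepCounter` that tracks only the step number, plus a hypothetical `positions_for` helper that derives absolute sequence positions from it. It assumes a decoding loop that first feeds a multi-token prompt and then one token per step, which is why the increment is not always 1.

```python
class StepCounter:
    """Hypothetical stand-in that tracks only the step number."""

    def __init__(self, max_num_steps: int) -> None:
        self.max_num_steps = max_num_steps
        self.step_nr = 0

    def increment_step_nr(self, value: int = 1) -> None:
        self.step_nr += value


def positions_for(chunk_len: int, counter: StepCounter) -> list[int]:
    """Absolute positions of the tokens fed in the current step."""
    return list(range(counter.step_nr, counter.step_nr + chunk_len))


counter = StepCounter(max_num_steps=8)
seen = []

# The first step feeds a 3-token prompt; later steps feed one token each.
for chunk_len in [3, 1, 1]:
    seen.append(positions_for(chunk_len, counter))
    # Advance only after the step has been processed.
    counter.increment_step_nr(chunk_len)
```

If the increment were skipped, every step would appear to start at position 0, which is the kind of error a position-dependent module cannot detect on its own.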

maybe_get_state(m, kls)[source]

Gets the state of m if present in the bag.

Parameters:
  • m (Module) – The module.

  • kls (type[T]) – The expected type of the state. If the type of the state in the bag does not match kls, None will be returned.

Returns:

The state of the module.

Return type:

T | None
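A common way to use a lookup like this is a get-or-create pattern inside a module's forward pass. The sketch below is a dependency-free stand-in rather than the library class: `TinyBag`, `Layer`, `KeyCache`, and `OtherState` are hypothetical, and the type check is modeled with `isinstance`, which is an assumption about how the type match is decided.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


class TinyBag:
    """Hypothetical stand-in that stores one state object per module."""

    def __init__(self) -> None:
        self._states: dict[object, object] = {}

    def set_state(self, m: object, state: object) -> None:
        self._states[m] = state

    def maybe_get_state(self, m: object, kls: type[T]) -> Optional[T]:
        state = self._states.get(m)
        # Return None both when nothing is stored and when the type differs.
        return state if isinstance(state, kls) else None


class Layer:
    """Placeholder for a module; instances are hashable by identity."""


class KeyCache:
    def __init__(self) -> None:
        self.keys: list[int] = []


class OtherState:
    """An unrelated state type, used to show the type check."""


bag = TinyBag()
layer = Layer()

# First step: nothing cached yet, so create and register the state.
missing = bag.maybe_get_state(layer, KeyCache)
if missing is None:
    bag.set_state(layer, KeyCache())

# Later steps: the same state object comes back and can be updated in place.
cache = bag.maybe_get_state(layer, KeyCache)
cache.keys.append(42)

# Asking for a different type yields None even though a state is stored.
mismatch = bag.maybe_get_state(layer, OtherState)
```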

set_state(m, state)[source]

Sets the state of m.

Parameters:
  • m (Module) – The module.

  • state (IncrementalState) – The state to set.

reorder(new_order)[source]

Reorders the module states.

See IncrementalState.reorder() for more information.

property step_nr: int

The current step number.

property max_num_steps: int

The maximum number of steps.

property capacity_increment: int | None

The sequence length capacity of state tensors will be incremented by multiples of this value.

size_bytes()[source]

Returns the size of the state bag in bytes.

Return type:

int

capacity_bytes()[source]

Returns the reserved capacity of the state bag in bytes.

Return type:

int