Incremental State¶
- class fairseq2.nn.IncrementalState[source]¶
Bases: ABC
Holds the state of a module during incremental decoding.
Incremental decoding is a special mode at inference time where the module only receives an input corresponding to the previous output and must produce the next output incrementally. Thus the module must cache any long-term state that is needed about the sequence.
- abstract reorder(new_order)[source]¶
Rearranges the state according to a new batch order.
This will be called when the order of the batch has changed. A typical use case is beam search, where the batch order changes between steps based on the selection of beams.
- Parameters:
new_order (Tensor) – The new order of the batch. It is frequently used with torch.index_select() to rearrange the state tensors. Shape: \((N)\), where \(N\) is the batch size.
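A minimal sketch of implementing reorder() for the beam-search case, using plain Python lists in place of tensors: each batch element owns one cache, and reorder() re-indexes the caches the same way torch.index_select() would re-index a state tensor along the batch dimension. Class and attribute names are illustrative.

```python
# Sketch of a concrete IncrementalState-like class whose reorder()
# follows the index_select semantics: entry i of the new state is
# taken from old entry new_order[i].
from abc import ABC, abstractmethod


class IncrementalStateSketch(ABC):
    @abstractmethod
    def reorder(self, new_order):
        ...


class CachedState(IncrementalStateSketch):
    def __init__(self, per_batch_caches):
        self.caches = per_batch_caches  # one cache per batch element

    def reorder(self, new_order):
        # Equivalent to torch.index_select(state, dim=0, index=new_order).
        self.caches = [self.caches[i] for i in new_order]


state = CachedState([["a"], ["b"], ["c"]])
# Beam search kept beam 2 once and beam 0 twice for the next step.
state.reorder([2, 0, 0])
```

Note that an index may repeat (a beam selected more than once) or vanish (a pruned beam), which is why the state is rebuilt by gathering rather than permuted in place.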
- final class fairseq2.nn.IncrementalStateBag(max_num_steps, *, capacity_increment=16)[source]¶
Bases: object
Holds the module states during incremental decoding.
- Parameters:
max_num_steps (int) – The expected maximum number of steps to take.
capacity_increment (int) – The sequence length capacity of state tensors will be incremented by multiples of this value.
- increment_step_nr(value=1)[source]¶
Increments the step number.
This method should be called after every decoding step. It is used by modules to keep track of the position in the sequence.
- Parameters:
value (int) – The value by which to increment the step number.
- set_state(m, state)[source]¶
Sets the state of m.
- Parameters:
m (Module) – The module.
state (IncrementalState) – The state to store.
- reorder(new_order)[source]¶
Reorders the module states.
See IncrementalState.reorder() for more information.
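Tying the pieces together, here is a hedged sketch of how a state bag can store one state per module and fan a reorder() call out to every stored state. Keying by module identity, the get_state helper, and the ListState/Dummy classes are assumptions of this sketch, not the fairseq2 implementation.

```python
# Sketch of a state bag: set_state(m, state) stores the state of a
# module, and reorder(new_order) forwards the new batch order to every
# stored state, matching IncrementalState.reorder() semantics.

class BagSketch:
    def __init__(self):
        self._states = {}  # id(module) -> state

    def set_state(self, m, state):
        self._states[id(m)] = state

    def get_state(self, m):  # illustrative helper, not from the docs above
        return self._states.get(id(m))

    def reorder(self, new_order):
        for state in self._states.values():
            state.reorder(new_order)


class ListState:
    def __init__(self, rows):
        self.rows = rows  # one cached row per batch element

    def reorder(self, new_order):
        self.rows = [self.rows[i] for i in new_order]


class Dummy:  # stands in for a torch.nn.Module
    pass


attn = Dummy()
bag = BagSketch()
bag.set_state(attn, ListState([["x"], ["y"]]))
bag.reorder([1, 0])  # swap the two batch elements
```

Because the bag holds the states of all modules, a single reorder() call on the bag keeps every module's cached state consistent with the new beam order.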