Incremental State
- class fairseq2.nn.IncrementalState
Bases: ABC
Holds the state of a module during incremental decoding.
Incremental decoding is a special mode at inference time in which the module receives only the input corresponding to the previous output and must produce the next output incrementally. The module must therefore cache any long-term state about the sequence that later steps need.
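To make the idea concrete, the toy module below (hypothetical, not part of fairseq2) outputs the running mean of every input it has seen. In incremental mode it receives only the newest input per call and keeps the running sum and count in a cache, so it produces the same outputs as full recomputation without ever revisiting the prefix.

```python
class RunningMean:
    """Hypothetical toy 'module': outputs the mean of all inputs seen so far."""

    def full(self, seq: list[float]) -> list[float]:
        # Non-incremental: recompute each prefix mean from the whole sequence.
        return [sum(seq[: i + 1]) / (i + 1) for i in range(len(seq))]

    def step(self, x: float, cache: dict) -> float:
        # Incremental: only the newest input arrives; the long-term state
        # (running sum and count) lives in the cache between calls.
        cache["sum"] = cache.get("sum", 0.0) + x
        cache["count"] = cache.get("count", 0) + 1
        return cache["sum"] / cache["count"]


module = RunningMean()
seq = [2.0, 4.0, 6.0, 8.0]

cache: dict = {}
incremental = [module.step(x, cache) for x in seq]

print(incremental)       # [2.0, 3.0, 4.0, 5.0]
print(module.full(seq))  # [2.0, 3.0, 4.0, 5.0]
```

An IncrementalState subclass plays the role of `cache` here, holding whatever a real module (for example an attention layer) needs from earlier steps.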
- abstract reorder(new_order: Tensor) → None
Rearranges the state according to a new batch order.
This is called whenever the order of the batch changes. A typical use case is beam search, where the batch order changes between steps depending on which beams are selected.
- Parameters:
new_order – The new order of the batch: element i is the index, in the old order, of the batch entry that moves to position i. It is frequently used with torch.index_select() to rearrange the state tensors. Shape: \((N)\), where \(N\) is the batch size.
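The sketch below shows what a concrete state's reorder() might look like during beam search. To stay dependency-free it defines a local abstract class as a stand-in for fairseq2.nn.IncrementalState and uses Python lists in place of tensors; with real tensors the list comprehension would correspond to torch.index_select(tensor, 0, new_order), assuming the batch is dimension 0. The class `PrefixState` is hypothetical.

```python
from abc import ABC, abstractmethod


class IncrementalState(ABC):
    """Local stand-in mirroring the interface of fairseq2.nn.IncrementalState."""

    @abstractmethod
    def reorder(self, new_order: list[int]) -> None:
        ...


class PrefixState(IncrementalState):
    """Hypothetical state: the tokens generated so far, one list per batch row."""

    def __init__(self, prefixes: list[list[int]]) -> None:
        self.prefixes = prefixes

    def append(self, tokens: list[int]) -> None:
        # Extend every batch row by its newly generated token.
        for prefix, token in zip(self.prefixes, tokens):
            prefix.append(token)

    def reorder(self, new_order: list[int]) -> None:
        # Row i of the new state is row new_order[i] of the old state, as with
        # torch.index_select(state, 0, new_order). Rows are copied because beam
        # search may select the same row more than once.
        self.prefixes = [list(self.prefixes[i]) for i in new_order]


# Two beams; beam search keeps two different continuations of beam 1.
state = PrefixState([[5, 7], [5, 9]])
state.reorder([1, 1])
state.append([3, 4])
print(state.prefixes)  # [[5, 9, 3], [5, 9, 4]]
```

The copy in reorder() matters: without it both rows would alias one list, and appending to one beam would corrupt the other. torch.index_select likewise returns a new tensor rather than a view.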
- final class fairseq2.nn.IncrementalStateBag(max_num_steps: int, *, capacity_increment: int | None = 16)
Bases: object
Holds the module states during incremental decoding.
- Parameters:
max_num_steps – The maximum number of steps to take.
capacity_increment – The sequence-length capacity of state tensors is increased in multiples of this value. If None, state tensors are preallocated with a capacity of max_num_steps.
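One reading of capacity_increment, sketched as arithmetic: a state tensor holding seq_len steps is grown to the smallest multiple of capacity_increment that fits, while None reserves all max_num_steps up front. The helper name `capacity_for` and the cap at max_num_steps are assumptions for illustration, not fairseq2 code.

```python
from __future__ import annotations


def capacity_for(seq_len: int, max_num_steps: int, capacity_increment: int | None) -> int:
    """Hypothetical helper: capacity of a state tensor holding seq_len steps."""
    if capacity_increment is None:
        # Preallocate the maximum once; no growth is ever needed.
        return max_num_steps
    # Round seq_len up to the next multiple of capacity_increment (integer
    # ceiling division), assumed here to be capped at max_num_steps.
    rounded = -(-seq_len // capacity_increment) * capacity_increment
    return min(rounded, max_num_steps)


for n in (1, 16, 17, 40):
    print(n, capacity_for(n, max_num_steps=40, capacity_increment=16))
# 1 16
# 16 16
# 17 32
# 40 40
```

Under this reading, a smaller increment wastes less memory on short outputs at the cost of more frequent reallocation, whereas None trades memory for never reallocating.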
- increment_step_nr(value: int = 1) → None
Increments the step number.
This method should be called after every decoding step. It is used by modules to keep track of the position in the sequence.
- Parameters:
value – The value by which to increment the step number.
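A minimal decoding loop showing when the step number advances. It assumes the bag exposes the current step number as an attribute named `step_nr` (not documented above) and that a module reads it as the position of its incoming input. `ToyBag` and `positions_fed` are dependency-free, hypothetical stand-ins. Presumably the `value` parameter serves a prefill call that processes a whole prompt at once.

```python
class ToyBag:
    """Stand-in for IncrementalStateBag that tracks only the step number."""

    def __init__(self, max_num_steps: int) -> None:
        self.max_num_steps = max_num_steps  # kept for parity; unused here
        self.step_nr = 0  # assumed attribute name

    def increment_step_nr(self, value: int = 1) -> None:
        self.step_nr += value


def positions_fed(bag: ToyBag, prompt_len: int, num_new: int) -> list[int]:
    """Record the starting position a module would see at each call."""
    seen = []
    # Prefill: the whole prompt goes through in one call starting at step 0,
    # so the step number advances by the prompt length afterwards.
    seen.append(bag.step_nr)
    bag.increment_step_nr(prompt_len)
    # Generation: one new token per call, one step per call.
    for _ in range(num_new):
        seen.append(bag.step_nr)
        bag.increment_step_nr()
    return seen


bag = ToyBag(max_num_steps=10)
print(positions_fed(bag, prompt_len=4, num_new=3))  # [0, 4, 5, 6]
print(bag.step_nr)  # 7
```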
- maybe_get_state(m: Module, kls: type[T]) → T | None
Gets the state of m if present in the bag.
- Parameters:
m – The module.
kls – The expected type of the state. If the type of the state in the bag does not match kls, None will be returned.
- Returns:
The state of the module, or None if the bag holds no state of type kls for m.
- set_state(m: Module, state: IncrementalState) → None
Sets the state of m.
- Parameters:
m – The module.
state – The state to store.
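maybe_get_state() and set_state() pair naturally into a get-or-create pattern inside a module's forward pass. The sketch uses dependency-free stand-ins: `ToyBag`, `CounterState`, and `Layer` are hypothetical, and detecting a type mismatch with isinstance() is an assumption about how "does not match kls" is tested.

```python
from __future__ import annotations

from typing import Optional, TypeVar

T = TypeVar("T")


class ToyBag:
    """Stand-in for IncrementalStateBag's state storage, keyed by module."""

    def __init__(self) -> None:
        self._states: dict = {}

    def maybe_get_state(self, m: object, kls: type[T]) -> Optional[T]:
        state = self._states.get(m)
        # Assumed: a missing state or a type mismatch both yield None.
        return state if isinstance(state, kls) else None

    def set_state(self, m: object, state: object) -> None:
        self._states[m] = state


class CounterState:
    """Hypothetical per-module state: how many calls the module has seen."""

    def __init__(self) -> None:
        self.calls = 0


class Layer:
    """Hypothetical module that creates its state lazily on first use."""

    def forward(self, bag: ToyBag) -> int:
        state = bag.maybe_get_state(self, CounterState)
        if state is None:
            state = CounterState()
            bag.set_state(self, state)
        state.calls += 1
        return state.calls


bag = ToyBag()
layer = Layer()
print([layer.forward(bag) for _ in range(3)])  # [1, 2, 3]
print(bag.maybe_get_state(layer, int))         # None: stored state is not an int
```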
- reorder(new_order: Tensor) → None
Reorders the module states. See IncrementalState.reorder() for more information.
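A plausible picture of the bag-level reorder(): it forwards the same new_order to every stored state, so all modules stay aligned with the new batch order after a beam-search step. The classes below are dependency-free, hypothetical stand-ins, not fairseq2's implementation.

```python
from abc import ABC, abstractmethod


class IncrementalState(ABC):
    """Local stand-in mirroring the interface of fairseq2.nn.IncrementalState."""

    @abstractmethod
    def reorder(self, new_order: list[int]) -> None:
        ...


class RowsState(IncrementalState):
    """Hypothetical state holding one value per batch row."""

    def __init__(self, rows: list[str]) -> None:
        self.rows = rows

    def reorder(self, new_order: list[int]) -> None:
        self.rows = [self.rows[i] for i in new_order]


class ToyBag:
    """Stand-in bag: reorder() applies one new_order to every stored state."""

    def __init__(self) -> None:
        self._states: dict[object, IncrementalState] = {}

    def set_state(self, m: object, state: IncrementalState) -> None:
        self._states[m] = state

    def reorder(self, new_order: list[int]) -> None:
        for state in self._states.values():
            state.reorder(new_order)


attn, ffn = object(), object()  # hypothetical module keys
bag = ToyBag()
bag.set_state(attn, RowsState(["a0", "a1", "a2"]))
bag.set_state(ffn, RowsState(["f0", "f1", "f2"]))
bag.reorder([2, 0, 0])
print([s.rows for s in bag._states.values()])
# [['a2', 'a0', 'a0'], ['f2', 'f0', 'f0']]
```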