moolib.Batcher

class moolib.Batcher

An auxiliary class that asynchronously batches tensors to a chosen batch size on a given device.

This class is typically used in an RL setting (see examples/impala/impala.py), with batching happening in a background thread. It is awaitable with asyncio.
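What "awaitable" means in practice can be sketched without moolib. The toy class below is a hypothetical stand-in, not moolib's implementation: ToyAwaitableBatcher and its internals are invented for illustration, it works on plain Python values rather than tensors, and it assumes that awaiting the batcher resolves to the next completed batch (the text above does not spell out the exact awaiting syntax).

```python
import asyncio


class ToyAwaitableBatcher:
    """Hypothetical stand-in: collects items and can be awaited for a batch."""

    def __init__(self, size):
        self.size = size
        self._pending = []             # items not yet forming a full batch
        self._ready = asyncio.Queue()  # completed batches

    def stack(self, item):
        self._pending.append(item)
        if len(self._pending) == self.size:
            self._ready.put_nowait(self._pending)
            self._pending = []

    def __await__(self):
        # Delegating to the queue's coroutine makes `await batcher` valid.
        return self._ready.get().__await__()


async def main():
    batcher = ToyAwaitableBatcher(size=3)

    async def produce():
        for i in range(3):
            batcher.stack(i)
            await asyncio.sleep(0)  # yield control to the event loop

    producer = asyncio.create_task(produce())
    batch = await batcher  # suspends this coroutine, not the whole thread
    await producer
    return batch


result = asyncio.run(main())
print(result)
```

The point of the sketch is that a coroutine waiting for a batch gives control back to the event loop, so other tasks (here, the producer) keep running in the meantime.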

Example

The batcher can take a series of data points and return a batched version:

import torch
from moolib import Batcher

batcher = Batcher(size=32, device="cuda:0", dim=0)

datapoint = {
    "batchme": torch.ones([2, 7]),
    "pliss": torch.randn([3]),
}

while True:
    # Stack each data point along a new leading dim; 32 of them form a batch.
    batcher.stack(datapoint)
    while not batcher.empty():
        mb = batcher.get()

        mb["batchme"]  # shape [32, 2, 7]
        mb["pliss"]  # shape [32, 3]

__init__()

Initialize the batcher.

Parameters
  • size (int) – the batch size to batch to.

  • device (str) – the device to batch tensors on.

  • dim (int) – the tensor dim to perform the operations along.

Methods

  • __init__ – Initialize the batcher.

  • cat – Batch the tensors by concatenating along the target dim.

  • empty – Returns True if there are no batched tensors to get.

  • get – Return a batched tensor.

  • stack – Batch the tensors by stacking along the target dim.

cat()

Batch the tensors by concatenating along the target dim.

Parameters
  • tensors – the tensors to concatenate.

empty()

Returns True if there are no batched tensors to get.

Returns

True if no batched tensors are ready to be retrieved, False otherwise.

Return type

bool

get()

Return a batched tensor.

This is a blocking call: it does not return until a batch is available.

Returns

batched tensor
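
The interplay between empty() and the blocking get() can be illustrated without moolib. The following is a toy stand-in built on queue.Queue, not moolib's implementation: the class ToyBatcher and its internals are hypothetical, and it groups plain Python values into lists instead of batching tensors onto a device.

```python
import queue
import threading


class ToyBatcher:
    """Hypothetical stand-in that mimics stack()/empty()/get() with lists."""

    def __init__(self, size):
        self.size = size
        self._pending = []           # items not yet forming a full batch
        self._ready = queue.Queue()  # completed batches
        self._lock = threading.Lock()

    def stack(self, item):
        with self._lock:
            self._pending.append(item)
            if len(self._pending) == self.size:
                self._ready.put(self._pending)
                self._pending = []

    def empty(self):
        return self._ready.empty()

    def get(self):
        # Blocks until a producer has completed a batch.
        return self._ready.get()


batcher = ToyBatcher(size=4)


def produce():
    for i in range(8):
        batcher.stack(i)


producer = threading.Thread(target=produce)
producer.start()

first = batcher.get()   # blocks until four items have been stacked
second = batcher.get()  # blocks until the next four have arrived
producer.join()
print(first, second, batcher.empty())
```

Checking empty() before calling get(), as in the example near the top of this page, is the way to avoid blocking when no batch is ready yet.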

stack()

Batch the tensors by stacking along the target dim.

Parameters
  • tensors – the tensors to stack.
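
The difference between cat() and stack() comes down to shape arithmetic. The sketch below works on plain shape tuples rather than tensors and assumes the two methods follow the same shape rules as torch.cat and torch.stack. The helper names cat_shape and stack_shape are hypothetical and not part of moolib, and only non-negative dim values are handled.

```python
def cat_shape(shapes, dim=0):
    """Shape after concatenating tensors with the given shapes along `dim`."""
    first = shapes[0]
    for s in shapes:
        # Every dim except `dim` must match for concatenation to be valid.
        if len(s) != len(first) or any(
            a != b for i, (a, b) in enumerate(zip(s, first)) if i != dim
        ):
            raise ValueError(f"incompatible shapes: {s} vs {first}")
    out = list(first)
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


def stack_shape(shapes, dim=0):
    """Shape after stacking identically shaped tensors along a new `dim`."""
    first = shapes[0]
    if any(s != first for s in shapes):
        raise ValueError("stacking requires identical shapes")
    out = list(first)
    out.insert(dim, len(shapes))
    return tuple(out)


# 32 data points of shape (2, 7), stacked: a new leading axis of length 32.
stacked = stack_shape([(2, 7)] * 32, dim=0)
# 16 data points of shape (2, 7), concatenated: 16 * 2 rows on the existing axis.
catted = cat_shape([(2, 7)] * 16, dim=0)
print(stacked, catted)
```

Stacking adds an axis, so per-item tensors of any (identical) shape can be batched together; concatenating grows an existing axis, so the inputs need to agree on every other dim.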