moolib.Batcher¶
- class moolib.Batcher¶
An auxiliary class to asynchronously batch tensors into a chosen batch size on a given device.
This class is often used in an RL setting (see examples/impala/impala.py), with batching happening in a background thread. This class is awaitable with asyncio.
Example
The batcher can take a series of data points and return a batched version:
import torch
from moolib import Batcher

batcher = Batcher(size=32, device="cuda:0", dim=0)
datapoint = {"batchme": torch.ones([2, 7]), "pliss": torch.randn([3])}
while True:
    batcher.cat(datapoint)
    while not batcher.empty():
        mb = batcher.get()
        mb["batchme"]  # [32, 2, 7]
        mb["pliss"]  # [32, 3]
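Since the class is awaitable, a batch can also be retrieved from an asyncio coroutine. The sketch below is illustrative only: the producer coroutine is hypothetical, and it assumes that awaiting the batcher resolves to the next full batch once one is ready (the exact await contract is not spelled out in this reference; see examples/impala/impala.py for real usage).

import asyncio
import torch
from moolib import Batcher

batcher = Batcher(size=32, device="cuda:0", dim=0)

async def producer():
    # Hypothetical data source feeding one datapoint at a time.
    while True:
        batcher.cat({"batchme": torch.ones([2, 7]), "pliss": torch.randn([3])})
        await asyncio.sleep(0)  # yield control to the consumer

async def consumer():
    for _ in range(3):
        mb = await batcher  # assumed to resolve to the next full batch
        print(mb["batchme"].shape, mb["pliss"].shape)

async def main():
    task = asyncio.create_task(producer())
    await consumer()
    task.cancel()

asyncio.run(main())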
- __init__()¶
Init.
- Parameters
size (int) – the target batch size.
device (str) – the device to batch tensors on.
dim (int) – the tensor dim along which to perform the batching operations (cat or stack).
Methods
__init__() – Init.
cat() – Batch the tensors by concatenating along the target dim.
empty() – Returns True if there are no batched tensors to get.
get() – Return a batched tensor.
stack() – Batch the tensors by stacking along the target dim.
- cat()¶
Batch the tensors by concatenating along the target dim.
- Parameters
tensors – the tensors to concatenate.
- empty()¶
Returns True if there are no batched tensors to get.
- Returns
True if there are no batched tensors to get, False otherwise.
- Return type
bool
- get()¶
Return a batched tensor.
This is a blocking call.
- Returns
batched tensor
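In the RL setting mentioned above, a common configuration batches time-major rollouts along dim=1 and drains completed batches with empty() and get(). The sketch below is illustrative only; the keys, shapes, device string, and the exact way size is counted are assumptions, not taken from examples/impala/impala.py.

import torch
from moolib import Batcher

T = 20          # hypothetical unroll length
batch_size = 8  # hypothetical batch size

# Batch time-major rollouts of shape [T, 1, ...] along dim 1.
batcher = Batcher(size=batch_size, device="cpu", dim=1)

for _ in range(4 * batch_size):
    rollout = {"obs": torch.randn([T, 1, 4]), "reward": torch.randn([T, 1])}
    batcher.cat(rollout)
    while not batcher.empty():
        mb = batcher.get()  # blocks only if called before a batch is complete
        # Expected to be batched along dim 1, e.g. obs of shape [T, batch_size, 4];
        # how `size` is counted is an assumption here.
        print(mb["obs"].shape, mb["reward"].shape)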
- stack()¶
Batch the tensors by stacking along the target dim.
- Parameters
tensors – the tensors to stack.
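For intuition on the difference between the two batching modes, the plain torch operations below show how concatenating grows an existing dim while stacking inserts a new one; this is only an analogy for the shapes produced by cat() and stack(), not a claim about Batcher internals.

import torch

parts = [torch.ones(2, 7) for _ in range(4)]

print(torch.cat(parts, dim=0).shape)    # torch.Size([8, 7])    - the existing dim grows
print(torch.stack(parts, dim=0).shape)  # torch.Size([4, 2, 7]) - a new dim is inserted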