#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import time
import torch
from rlstructures import DictTensor
[docs]class VecEnv:
"""
An VecEnvironment corresponds to multiple 'gym' environments (i.e a batch)
that are running simultaneously.
At each timestep, upon the B environments, a subset B' of envs are running
(since some envs may have stopped).
So each observation returned by the VecEnv is a DictTensor of size B'. To
mark which environments that are still running, the observation is returned
with a mapping vector of size B'. e.g [0,2,5] means that the observation 0
corresponds to the env 0, the observation 1 corresponds to env 2, and the
observation 3 corresponds to env 5.
Finally, when running a step (at time t) method (over B' running envs), the
agent has to provide an action (DictTensor) of size B'. The VecEnv will return
the next observation (time t+1) (size B'). But some of the B' envs may have
stopped at t+1, such that actually only B'' envs are still running. The
step method will thus also return a B'' observation (and corresponding
mapping).
The return of the step function is thus:
((DictTensor of size B', tensor of size B'),
(Dicttensor of size B'', mapping vector if size B''))
"""
def __init__(self):
pass
[docs] def reset(self, env_info: DictTensor = None):
"""reset the environments instances
:param env_info: a DictTensor of size n_envs, such that each value will be transmitted to each environment instance
:type env_info: DictTensor, optional
"""
pass
[docs] def step(
self, policy_output: DictTensor
) -> [[DictTensor, torch.Tensor], [DictTensor, torch.Tensor]]:
"""Execute one step over alll the running environment instances
:param policy_output: the output given by the policy
:type policy_output: DictTensor
:return: see general description
:rtype: [[DictTensor,torch.Tensor],[DictTensor,torch.Tensor]]
"""
raise NotImplementedError
[docs] def close(self):
"""Terminate the environment"""
raise NotImplementedError
[docs] def n_envs(self) -> int:
"""Returns the number of environment instances contained in this env
:rtype: int
"""
return self.reset()[0].n_elems()