# Source code for rlstructures.core

#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#


from __future__ import annotations
import torch
from typing import Iterable, List, Dict


def masked_tensor(tensor0, tensor1, mask):
    """Compute a tensor by combining two tensors row by row with a mask.

    :param tensor0: a Bx(N) tensor
    :type tensor0: torch.Tensor
    :param tensor1: a Bx(N) tensor
    :type tensor1: torch.Tensor
    :param mask: a B tensor (0/1 values, bool or numeric)
    :type mask: torch.Tensor
    :return: (1-m) * tensor0 + m * tensor1, where the mask value of row b is
        applied to every element of row b; the result is cast back to
        tensor0.dtype
    :rtype: torch.Tensor
    """
    s = tensor0.size()
    assert s[0] == mask.size()[0]
    # Reshape the mask to B x 1 x ... x 1 so that it broadcasts over every
    # non-batch dimension, whatever the rank of the tensors.
    m = mask.reshape(s[0], *([1] * (len(s) - 1))).float()
    out = ((1.0 - m) * tensor0 + m * tensor1).type(tensor0.dtype)
    return out
def masked_dicttensor(dicttensor0, dicttensor1, mask):
    """Apply `masked_tensor` key by key to two DictTensors.

    The keys are taken from ``dicttensor0``; each one must also exist in
    ``dicttensor1``.

    :param dicttensor0: the DictTensor used where the mask is 0
    :param dicttensor1: the DictTensor used where the mask is 1
    :param mask: a B tensor
    :rtype: DictTensor
    """
    combined = {
        name: masked_tensor(dicttensor0[name], dicttensor1[name], mask)
        for name in dicttensor0.keys()
    }
    return DictTensor(combined)
class DictTensor:
    """A dictionary of torch.Tensor.

    The first dimension of each tensor is the batch dimension, and all
    tensors are expected to have the same batch dimension size. All tensors
    must be on the same device (checked at construction).
    """

    def __init__(self, v: Dict = None):
        """Initialize the DictTensor with a dictionary of Tensors.

        All tensors must have the same first dimension size. The dictionary
        is stored as-is (not copied), so later calls to `set` also modify it.

        :param v: mapping from names to tensors; None creates an empty DictTensor
        :type v: dict, optional
        """
        if v is None:
            self.variables = {}
        else:
            self.variables = v
        d = None
        for k in self.variables.values():
            assert isinstance(k, torch.Tensor)
            if d is None:
                d = k.device
            else:
                assert d == k.device

    def _check(self) -> bool:
        """Check that all tensors have the same batch dimension size.

        :return: True if the check passes (an AssertionError is raised otherwise)
        :rtype: bool
        """
        s = None
        for v in self.variables.values():
            if s is None:
                s = v.size()[0]
            else:
                assert s == v.size()[0]
        return True

    def keys(self) -> Iterable[str]:
        """Return the keys of the DictTensor (as an iterator)."""
        return self.variables.keys()

    def __getitem__(self, key: str) -> torch.Tensor:
        """Get one particular tensor in the DictTensor.

        :param key: the name of the tensor
        :type key: str
        :return: the corresponding tensor
        :rtype: torch.Tensor
        """
        return self.variables[key]

    def get(self, keys: Iterable[str], clone=False) -> DictTensor:
        """Return a DictTensor composed of a subset of the tensors, specified by their keys.

        :param keys: the keys to keep in the new DictTensor
        :type keys: Iterable[str]
        :param clone: if True, the new DictTensor is composed of clones of the
            original tensors, defaults to False
        :type clone: bool, optional
        :rtype: DictTensor
        """
        d = DictTensor({k: self.variables[k] for k in keys})
        if clone:
            return d.clone()
        return d

    def clone(self) -> DictTensor:
        """Clone the DictTensor by cloning all its tensors.

        :rtype: DictTensor
        """
        return DictTensor({k: self.variables[k].clone() for k in self.variables})

    def specs(self):
        """Return the specifications of the DictTensor as a dictionary.

        For each key: ``{"size": <size of one batch element>, "dtype": <dtype>}``.
        """
        _specs = {}
        for k in self.variables:
            _specs[k] = {
                "size": self.variables[k][0].size(),
                "dtype": self.variables[k].dtype,
            }
        return _specs

    def device(self) -> torch.device:
        """Return the device of the tensors stored in the DictTensor.

        :return: the device, or None when the DictTensor is empty
        :rtype: torch.device
        """
        if self.empty():
            return None
        return next(iter(self.variables.values())).device

    def n_elems(self) -> int:
        """Return the size of the batch dimension (i.e. the first dimension of the tensors).

        An empty DictTensor reports 0.
        """
        if len(self.variables) > 0:
            f = next(iter(self.variables.values()))
            return f.size()[0]
        # TODO: Empty dicts should be handled better than this
        return 0

    def empty(self) -> bool:
        """Is the DictTensor empty? (no tensors in it)

        :rtype: bool
        """
        return len(self.variables) == 0

    def unfold(self) -> List[DictTensor]:
        """Return a list of DictTensor, one per element of the batch dimension.

        Each returned DictTensor satisfies n_elems() == 1.
        """
        r = []
        for i in range(self.n_elems()):
            v = {k: self.variables[k][i].unsqueeze(0) for k in self.variables}
            r.append(DictTensor(v))
        return r

    def slice(self, index_from: int, index_to: int = None) -> DictTensor:
        """Return a DictTensor restricted to batch elements ``index_from <= b < index_to``.

        :param index_from: the first batch index to keep
        :type index_from: int
        :param index_to: the last+1 batch index to keep. If None, only
            index_from is kept and the batch dimension is dropped
            (each tensor is indexed with ``[index_from]``)
        :type index_to: int, optional
        :rtype: DictTensor
        """
        if index_to is not None:
            v = {k: self.variables[k][index_from:index_to] for k in self.variables}
        else:
            v = {k: self.variables[k][index_from] for k in self.variables}
        return DictTensor(v)

    def index(self, index: int) -> DictTensor:
        """The same as self.slice(index): the batch dimension is dropped."""
        v = {k: self.variables[k][index] for k in self.variables}
        return DictTensor(v)

    @staticmethod
    def cat(tensors: Iterable[DictTensor]) -> DictTensor:
        """Aggregate multiple DictTensors over the batch dimension.

        The keys are taken from the first DictTensor.

        Args:
            tensors (iterable): the DictTensors to concatenate
        """
        tensors = list(tensors)
        if len(tensors) == 0:
            return DictTensor({})
        retour = {}
        for key in tensors[0].variables:
            retour[key] = torch.cat([t[key] for t in tensors], dim=0)
        return DictTensor(retour)

    def to(self, device: torch.device):
        """Create a copy of the DictTensor on a new device.

        If the DictTensor is already on ``device``, a clone is returned.
        """
        if self.empty():
            return DictTensor({})
        if device == self.device():
            return self.clone()
        v = {k: self.variables[k].to(device) for k in self.variables}
        return DictTensor(v)

    def set(self, key: str, value: torch.Tensor):
        """Add (or replace) a tensor in the DictTensor, in place.

        Args:
            key (str): the name of the tensor
            value (torch.Tensor): the tensor to add, with a correct batch
                dimension size (any size if the DictTensor is empty)
        """
        assert isinstance(value, torch.Tensor)
        # An empty DictTensor has no batch size yet: the first tensor defines it.
        assert self.empty() or value.size()[0] == self.n_elems()
        assert self.empty() or value.device == self.device()
        self.variables[key] = value

    def prepend_key(self, _str: str) -> DictTensor:
        """Return a new DictTensor where _str has been prepended to all the keys."""
        v = {_str + key: self.variables[key] for key in self.variables}
        return DictTensor(v)

    def truncate_key(self, _str: str) -> DictTensor:
        """Return a new DictTensor keeping only the keys prefixed by _str, with the prefix removed."""
        v = {}
        for k in self.variables:
            if k.startswith(_str):
                v[k[len(_str):]] = self.variables[k]
        return DictTensor(v)

    def __str__(self):
        return "DictTensor: " + str(self.variables)

    def __contains__(self, key: str) -> bool:
        return key in self.variables

    def __add__(self, dt: DictTensor) -> DictTensor:
        """Create a new DictTensor containing all the tensors from self and dt.

        The two DictTensors must be on the same device and have disjoint keys.
        If one of them is empty, a clone of the other is returned.
        """
        if self.empty():
            return dt.clone()
        if dt.empty():
            return self.clone()
        assert dt.device() == self.device()
        for k in dt.keys():
            assert k not in self.variables, (
                "variable " + k + " already exists in the DictTensor"
            )
        v = {**self.variables, **dt.variables}
        return DictTensor(v)

    def copy_(self, source, source_indexes, destination_indexes):
        """Copy, in place, the values of source at source_indexes into self at destination_indexes.

        Every key of self must exist in source.
        """
        assert source_indexes.size() == destination_indexes.size()
        for k in self.variables.keys():
            self.variables[k][destination_indexes] = source[k][source_indexes]
class TemporalDictTensor:
    """Describe a batch of temporal tensors where:

    * each tensor has a name
    * each tensor is of size B x T x ...., where B is the batch index, and T
      the time index
    * the length tensor gives the number of timesteps for each batch element

    It is an extension of DictTensor where a temporal dimension has been
    added. The structure also allows dealing with batches of sequences of
    different sizes.

    Note that self.lengths is a (long) tensor of the lengths of each element
    of the batch.
    """

    def __init__(self, from_dict: Dict[str, torch.Tensor], lengths: torch.Tensor = None):
        """
        Args:
            from_dict (dict of tensors): the tensors to store.
            lengths (long tensor): the length of each element in the batch.
                If None, then use the second dimension of the tensors as the
                length of every element.
        """
        self.variables = from_dict
        self._keys = list(self.variables.keys())
        self.lengths = lengths
        if self.lengths is None:
            first = self.variables[self._keys[0]]
            # Every sequence is considered full. The lengths are created on the
            # tensors' device so that device() reports where the data lives.
            self.lengths = torch.full(
                (self.n_elems(),), first.size()[1], dtype=torch.long, device=first.device
            )
        assert self.lengths.dtype == torch.int64
        self._specs = None

    def clone(self):
        """Return a deep copy (all tensors and the lengths are cloned)."""
        v = {k: self.variables[k].clone() for k in self.variables}
        return TemporalDictTensor(v, lengths=self.lengths.clone())

    def set(self, name, tensor):
        """Add (or replace) a B x T x ... tensor, in place."""
        if name not in self.variables:
            self._keys.append(name)
        self.variables[name] = tensor
        # The cached specs no longer describe the stored variables.
        self._specs = None

    def specs(self):
        """Return (and cache) the specifications of the variables.

        For each key: ``{"size": <size of one (batch, time) element>, "dtype": <dtype>}``,
        the same format as DictTensor.specs.
        """
        if self._specs is None:
            self._specs = {
                k: {
                    "size": self.variables[k][0][0].size(),
                    "dtype": self.variables[k].dtype,
                }
                for k in self.variables
            }
        return self._specs

    def device(self) -> torch.device:
        """Returns the device of the TemporalDictTensor (the device of self.lengths)."""
        return self.lengths.device

    def n_elems(self) -> int:
        """Returns the number of elements in the TemporalDictTensor (i.e. size of the first dimension of each tensor)."""
        return self.variables[self._keys[0]].size()[0]

    def keys(self) -> Iterable[str]:
        """Returns the keys in the TemporalDictTensor."""
        return self.variables.keys()

    def mask(self) -> torch.Tensor:
        """Returns a mask over sequences based on the length of each trajectory.

        Considering that the TemporalDictTensor is of size B x T, the mask is
        a float tensor (0.0 or 1.0) of size BxT. A 0.0 value means that the
        value at b x t is not set in the TemporalDictTensor.
        """
        # All variables are expected to share the same T.
        max_length = next(iter(self.variables.values())).size()[1]
        time_steps = torch.arange(max_length, device=self.lengths.device).unsqueeze(0)
        # (1 x T) < (B x 1) broadcasts to B x T
        return time_steps.lt(self.lengths.unsqueeze(1)).float()

    def __getitem__(self, key: str) -> torch.Tensor:
        """Returns a single tensor of size B x T x ....

        Args:
            key (str): the name of the variable
        """
        return self.variables[key]

    def shorten(self) -> TemporalDictTensor:
        """Restrict the size of the variables (in terms of timesteps) to the smallest possible tensors.

        If the TemporalDictTensor is of size B x T, considering that
        Tmax = self.lengths.max(), it returns a TemporalDictTensor of size
        B x Tmax.
        """
        ml = int(self.lengths.max().item())
        v = {k: self.variables[k][:, :ml] for k in self.variables}
        return TemporalDictTensor(v, self.lengths.clone())

    def unfold(self) -> List[TemporalDictTensor]:
        """Return a list of TemporalDictTensor of size 1 x T."""
        r = []
        for i in range(self.n_elems()):
            v = {k: self.variables[k][i].unsqueeze(0) for k in self.variables}
            l = self.lengths[i].unsqueeze(0)
            r.append(TemporalDictTensor(v, l))
        return r

    def get(self, keys: Iterable[str]) -> TemporalDictTensor:
        """Returns a subset of the TemporalDictTensor depending on the specified keys.

        Args:
            keys (iterable): the keys to keep in the new TemporalDictTensor
                (must not be a single string)
        """
        assert not isinstance(keys, str)
        return TemporalDictTensor({k: self.variables[k] for k in keys}, self.lengths)

    def slice(self, index_from: int, index_to: int = None) -> TemporalDictTensor:
        """Returns a slice (in the batch dimension).

        If index_to is None, returns the 1 x T TemporalDictTensor at index_from.
        """
        if index_to is not None:
            v = {k: self.variables[k][index_from:index_to] for k in self.variables}
            l = self.lengths[index_from:index_to]
        else:
            v = {k: self.variables[k][index_from].unsqueeze(0) for k in self.variables}
            l = self.lengths[index_from].unsqueeze(0)
        return TemporalDictTensor(v, l)

    def temporal_slice(self, index_from: int, index_to: int) -> TemporalDictTensor:
        """Returns a slice (in the temporal dimension), timesteps index_from <= t < index_to.

        The new lengths are the old lengths shifted by index_from, clipped to
        [0, index_to - index_from].
        """
        v = {k: self.variables[k][:, index_from:index_to] for k in self.variables}
        l = torch.clamp(self.lengths - index_from, min=0)
        l = torch.clamp(l, max=index_to - index_from)
        return TemporalDictTensor(v, l.long())

    def index(self, index: int) -> TemporalDictTensor:
        """Returns the TemporalDictTensor for the specified batch index.

        NOTE(review): the batch dimension is dropped here (tensors are T x ...
        and lengths is 0-dim), unlike slice(index) which keeps a 1 x T shape.
        """
        v = {k: self.variables[k][index][:] for k in self.variables}
        l = self.lengths[index]
        return TemporalDictTensor(v, l)

    def temporal_index(self, index_t: int) -> TemporalDictTensor:
        """Return a DictTensor corresponding to the TemporalDictTensor at time index_t."""
        return DictTensor({k: self.variables[k][:, index_t] for k in self.variables})

    def temporal_multi_index(self, index_t: torch.Tensor) -> TemporalDictTensor:
        """Return a DictTensor where batch element b is taken at time index_t[b]."""
        a = torch.arange(self.n_elems()).to(self.device())
        return DictTensor({k: self.variables[k][a, index_t] for k in self.variables})

    def masked_temporal_index(self, index_t: int) -> [DictTensor, torch.Tensor]:
        """Return a DictTensor at time index_t along with a mapping vector.

        Considering the TemporalDictTensor is of size BxT, the method returns
        a DictTensor of size B' and a tensor of size B' where:
        * only the B' batch elements with index_t < self.lengths are kept
        * the mapping vector maps each of the B' elements to its index in the
          B dimension of the original TemporalDictTensor
        """
        valid = self.lengths.gt(index_t)
        v = {k: self.variables[k][valid, index_t] for k in self.variables}
        mapping = torch.arange(self.n_elems(), device=self.lengths.device)[valid]
        return DictTensor(v), mapping

    @staticmethod
    def cat(tensors: Iterable[TemporalDictTensor]) -> TemporalDictTensor:
        """Aggregate multiple TemporalDictTensors over the batch dimension.

        Shorter tensors are zero-padded in the time dimension. The keys are
        taken from the first TemporalDictTensor.

        Args:
            tensors (iterable): the TemporalDictTensors to concatenate
        """
        tensors = list(tensors)
        lengths = torch.cat([t.lengths for t in tensors])
        lm = lengths.max().item()
        retour = {}
        for key in tensors[0].keys():
            # Pad to the longest time dimension (at least the max length) so
            # that all tensors can be concatenated.
            target = max([lm] + [t[key].size()[1] for t in tensors])
            to_concat = []
            for t in tensors:
                v = t[key]
                s = v.size()
                pad = (s[0], target - s[1]) + tuple(s[2:])
                if pad[1] > 0:
                    toadd = torch.zeros(pad, dtype=v.dtype, device=v.device)
                    v = torch.cat([v, toadd], dim=1)
                to_concat.append(v)
            retour[key] = torch.cat(to_concat, dim=0)
        return TemporalDictTensor(retour, lengths)

    def to(self, device: torch.device):
        """Returns a copy of the TemporalDictTensor on the provided device (self if already there)."""
        if device == self.device():
            return self
        lengths = self.lengths.to(device)
        v = {k: self.variables[k].to(device) for k in self.variables}
        return TemporalDictTensor(v, lengths)

    def __str__(self):
        r = ["TemporalDictTensor:"]
        for k in self.variables:
            r.append(k + ":" + str(self.variables[k].size()))
        # .cpu() so that printing also works for tensors stored on a GPU
        r.append("Lengths =" + str(self.lengths.cpu().numpy()))
        return " ".join(r)

    def __contains__(self, item: str) -> bool:
        return item in self.variables

    def full(self):
        """Returns True if self.lengths == self.lengths.max() for every element (no shorter element)."""
        return (self.lengths == self.lengths.max()).all()

    def expand(self, new_batch_size):
        """Expand a TemporalDictTensor to reach a given batch_size.

        New elements are zero-filled and have a length of 0.
        """
        assert new_batch_size > self.n_elems()
        diff = new_batch_size - self.n_elems()
        new_lengths = torch.zeros(new_batch_size, dtype=torch.long, device=self.device())
        new_lengths[0 : self.n_elems()] = self.lengths
        new_variables = {}
        for k, v in self.variables.items():
            s = v.size()
            # Match dtype and device so that torch.cat neither fails nor promotes the dtype.
            zeros = torch.zeros(diff, *s[1:], dtype=v.dtype, device=v.device)
            new_variables[k] = torch.cat([v, zeros])
        return TemporalDictTensor(new_variables, new_lengths)

    def copy_(self, source, source_indexes, destination_indexes):
        """Copy, in place, the values of a source TDT at given indexes to the current TDT at the specified indexes.

        Only the first source.lengths.max() timesteps are copied; lengths are copied too.
        """
        assert source_indexes.size() == destination_indexes.size()
        max_length_source = source.lengths.max().item()
        for k in self.variables.keys():
            self.variables[k][destination_indexes, 0:max_length_source] = source[k][
                source_indexes, 0:max_length_source
            ]
        self.lengths[destination_indexes] = source.lengths[source_indexes]
class Trajectories:
    """A batch of trajectories.

    ``info`` is a DictTensor with one entry per trajectory (it may be empty).
    ``trajectories`` is a TemporalDictTensor holding the per-timestep data.
    When ``info`` is not empty, both must be on the same device and have the
    same batch size.
    """

    def __init__(self, info, trajectories):
        self.info = info
        self.trajectories = trajectories
        assert info.empty() or self.info.device() == self.trajectories.device()
        assert self.info.empty() or self.info.n_elems() == self.trajectories.n_elems()

    def to(self, device):
        """Return the Trajectories with both parts moved to ``device``."""
        return Trajectories(self.info.to(device), self.trajectories.to(device))

    def device(self):
        """Return the device of the temporal part."""
        return self.trajectories.device()

    @staticmethod
    def cat(trajectories: Iterable[Trajectories]):
        """Concatenate several Trajectories over the batch dimension."""
        trajectories = list(trajectories)
        return Trajectories(
            DictTensor.cat([t.info for t in trajectories]),
            # Concatenate the temporal parts; the constructor does not accept a list.
            TemporalDictTensor.cat([t.trajectories for t in trajectories]),
        )

    def n_elems(self):
        """Return the number of trajectories in the batch."""
        return self.trajectories.n_elems()

    def sample(self, n):
        raise NotImplementedError