Skip to content

salina.workspace

salina.workspace

CompactSharedTensor

It corresponds to a tensor in shared memory and is used when building a workspace shared by multiple processes. All its methods behave like the methods of SlicedTemporalTensor.

CompactTemporalTensor

A CompactTemporalTensor is a tensor of size TxBx... It behaves like the SlicedTemporalTensor but has a fixed size that cannot change. It is faster than the SlicedTemporalTensor. See SlicedTemporalTensor.

to_sliced(self)

Transform the tensor into a SlicedTemporalTensor

Source code in salina/workspace.py
def to_sliced(self) -> SlicedTemporalTensor:
    """Convert this compact tensor into an equivalent `SlicedTemporalTensor`.

    Each temporal slice ``self.tensor[t]`` is stored, without cloning, at
    timestep ``t`` of a freshly created sliced tensor.

    Returns:
        SlicedTemporalTensor: a sliced tensor with one entry per timestep.
    """
    sliced = SlicedTemporalTensor()
    n_steps = self.tensor.size()[0]
    for step in range(n_steps):
        sliced.set(step, self.tensor[step], None)
    return sliced

SlicedTemporalTensor

A SlicedTemporalTensor represents a tensor of size TxBx... by using a list of tensors of size Bx... The interest is that this tensor automatically adapts its timestep dimension and does not need to have a predefined size.

__init__(self) special

Initialize an empty tensor

Source code in salina/workspace.py
def __init__(self):
    """Initialize an empty tensor.

    No timestep is stored yet. The size, device and dtype stay None until
    the first value is written with `set`, which records them from that
    value.
    """
    self.tensors: list[torch.Tensor] = []  # one Bx... tensor per timestep
    self.size: Optional[torch.Size] = None
    self.device: Optional[torch.device] = None
    self.dtype: Optional[torch.dtype] = None

batch_size(self)

Return the size of the batch dimension

Source code in salina/workspace.py
def batch_size(self):
    """Return the size of the batch dimension.

    The value is read from the first dimension of the slice stored at
    timestep 0. Raises IndexError if no timestep is stored yet.
    """
    first_slice = self.tensors[0]
    return first_slice.size()[0]

clear(self)

Clear the tensor

Source code in salina/workspace.py
def clear(self):
    """Reset the tensor to its freshly-constructed, empty state.

    All stored timesteps are dropped, and size, device and dtype are forgotten.
    """
    self.tensors = []
    self.size = self.device = self.dtype = None

copy_time(self, from_time, to_time, n_steps)

Copy temporal slices of the tensor from from_time:from_time+n_steps to to_time:to_time+n_steps

Source code in salina/workspace.py
def copy_time(self, from_time: int, to_time: int, n_steps: int):
    """Copy slices ``from_time:from_time+n_steps`` to ``to_time:to_time+n_steps``.

    All source slices are read before any destination slice is written. The
    result is therefore correct even when the two ranges overlap: for example,
    ``copy_time(0, 1, 2)`` shifts two steps forward by one.

    Args:
        from_time: first timestep to read.
        to_time: first timestep to write. The tensor grows if needed, because
            `set` pads missing timesteps.
        n_steps: number of timesteps to copy.
    """
    # Snapshot the sources first. Reading and writing in the same loop would
    # make an overlapping forward copy re-read slices it has just overwritten.
    sources = [self.get(from_time + t, batch_dims=None) for t in range(n_steps)]
    for t, value in enumerate(sources):
        self.set(to_time + t, value, batch_dims=None)

get(self, t, batch_dims)

Get the value of the tensor at time t

Source code in salina/workspace.py
def get(self, t: int, batch_dims: Optional[tuple[int, int]]):
    """Get the value (a Bx... tensor) stored at timestep t.

    Args:
        t: timestep to read. It must be smaller than the number of stored
            timesteps.
        batch_dims: must be None, since batch slicing is not supported here.

    Returns:
        torch.Tensor: the slice stored at timestep t.
    """
    assert (
        batch_dims is None
    ), "Unable to use batch dimensions with SlicedTemporalTensor"
    assert t < len(self.tensors), "Temporal index out of bounds"
    return self.tensors[t]

get_full(self, batch_dims)

Returns the complete tensor of size TxBx...

Source code in salina/workspace.py
def get_full(self, batch_dims):
    """Stack every stored timestep into a single TxBx... tensor.

    Args:
        batch_dims: must be None, since batch slicing is not supported here.
    """
    assert (
        batch_dims is None
    ), "Unable to use batch dimensions with SlicedTemporalTensor"
    # Stacking along a new leading dim is the same as concatenating the
    # slices after unsqueezing each of them at dim 0.
    return torch.stack(self.tensors, dim=0)

get_time_truncated(self, from_time, to_time, batch_dims)

Returns tensor[from_time:to_time]

Source code in salina/workspace.py
def get_time_truncated(self, from_time: int, to_time: int, batch_dims: Optional[tuple[int, int]]):
    """Return ``tensor[from_time:to_time]`` stacked into a single tensor.

    As with Python slicing, ``to_time`` is clipped to the number of stored
    timesteps. The result therefore has ``min(to_time, time_size) - from_time``
    timesteps.

    Args:
        from_time: first timestep (inclusive), >= 0.
        to_time: last timestep (exclusive), > from_time.
        batch_dims: must be None.
    """
    assert from_time >= 0 and to_time >= 0 and to_time > from_time
    assert batch_dims is None
    # The list slice performs the clipping to len(self.tensors).
    return torch.stack(self.tensors[from_time:to_time], dim=0)

select_batch(self, batch_indexes)

Return the tensor where the batch dimension has been selected by the index

Source code in salina/workspace.py
def select_batch(self, batch_indexes: torch.LongTensor):
    """Build a new SlicedTemporalTensor that keeps only the given batch entries.

    Each stored slice is indexed with ``batch_indexes`` along its first
    (batch) dimension.
    """
    selected = SlicedTemporalTensor()
    for step in range(len(self.tensors)):
        selected.set(step, self.tensors[step][batch_indexes], None)
    return selected

set(self, t, value, batch_dims)

Set a value (dim Bx...) at time t

Source code in salina/workspace.py
def set(self, t: int, value: torch.Tensor, batch_dims: Optional[tuple[int, int]]):
    """Set a value (a Bx... tensor) at timestep t.

    The first stored value fixes the size, device and dtype, and every later
    value must match them. When t lies past the last stored timestep, the gap
    is filled with zero tensors of that size, device and dtype.

    Args:
        t: timestep to write.
        value: tensor to store. It is stored by reference, not copied.
        batch_dims: must be None.
    """
    assert (
        batch_dims is None
    ), "Unable to use batch dimensions with SlicedTemporalTensor"
    if self.size is None:
        self.size = value.size()
        self.device = value.device
        self.dtype = value.dtype
    assert self.size == value.size(), "Incompatible size"
    assert self.device == value.device, "Incompatible device"
    assert self.dtype == value.dtype, "Incompatible type"
    # Grow the list so that index t exists, padding with zero slices.
    while len(self.tensors) <= t:
        self.tensors.append(
            torch.zeros(*self.size, device=self.device, dtype=self.dtype)
        )
    self.tensors[t] = value

set_full(self, value, batch_dims)

Set the tensor given a TxBx... tensor. The input tensor is cut along its first (time) dimension into slices that are stored in a list of tensors.

Source code in salina/workspace.py
def set_full(self, value: torch.Tensor, batch_dims: Optional[tuple[int, int]]):
    """Set the tensor from a complete TxBx... tensor.

    The input is cut along its first (time) dimension, and each slice
    ``value[t]`` is stored at timestep t through `set`.

    Args:
        value: the full tensor, with time as its first dimension.
        batch_dims: must be None.
    """
    assert (
        batch_dims is None
    ), "Unable to use batch dimensions with SlicedTemporalTensor"
    for t in range(value.size()[0]):
        self.set(t, value[t], batch_dims=batch_dims)

subtime(self, from_t, to_t)

Return tensor[from_t:to_t]

Source code in salina/workspace.py
def subtime(self, from_t: int, to_t: int):
    """Return ``tensor[from_t:to_t]`` packed into a CompactTemporalTensor."""
    window = self.tensors[from_t:to_t]
    return CompactTemporalTensor(torch.stack(window, dim=0))

time_size(self)

Return the size of the time dimension

Source code in salina/workspace.py
def time_size(self):
    """Return the number of stored timesteps, i.e. the size of the time dimension."""
    n_steps = len(self.tensors)
    return n_steps

to(self, device)

Move the tensor to a specific device

Source code in salina/workspace.py
def to(self, device: torch.device):
    """Return a new SlicedTemporalTensor with every slice moved to ``device``.

    The current tensor is not modified.

    Args:
        device: target device.
    """
    moved = SlicedTemporalTensor()
    for t, tensor in enumerate(self.tensors):
        # `set` has no default for batch_dims, so it must be passed explicitly.
        moved.set(t, tensor.to(device), None)
    return moved

zero_grad(self)

Clear any gradient information in the tensor

Source code in salina/workspace.py
def zero_grad(self):
    """Drop autograd history by replacing each stored slice with its detached version."""
    detached = []
    for tensor in self.tensors:
        detached.append(tensor.detach())
    self.tensors = detached

Workspace

Workspace is the most important class in SaLinA. It corresponds to a collection of tensors (`SlicedTemporalTensor`, `CompactTemporalTensor` or `CompactSharedTensor`). In the majority of cases, we consider that all the tensors have the same time and batch sizes (but it is not mandatory for most of the functions).

__getitem__(self, key) special

If key is a string, it returns a torch.Tensor; if key is a list of strings, it returns a tuple of torch.Tensor.

Source code in salina/workspace.py
def __getitem__(self, key):
    """Return the full TxBx... tensor(s) stored under ``key``.

    Args:
        key: a variable name, or an iterable of variable names.

    Returns:
        A torch.Tensor if ``key`` is a string. Otherwise, a tuple of
        torch.Tensor in the same order as ``key``.
    """
    if isinstance(key, str):
        return self.get_full(key, None)
    # Build a real tuple, as documented, rather than a lazy generator, so the
    # result can be indexed, has a length, and can be consumed more than once.
    return tuple(self.get_full(k, None) for k in key)

__init__(self, workspace=None) special

Create an empty workspace

Parameters:

Name Type Description Default
workspace Workspace

If specified, it creates a copy of the workspace (where tensors are cloned as CompactTemporalTensors)

None
Source code in salina/workspace.py
def __init__(self, workspace: Optional[Workspace] = None):
    """Create an empty workspace, or a copy of an existing one.

    Args:
        workspace (Workspace, optional): if given, each of its variables is
            read in full, cloned, and stored in the new workspace through
            `set_full`. Because the new workspace starts empty, they are stored
            as CompactTemporalTensors.
    """
    self.variables = {}
    self.is_shared = False
    if workspace is None:
        return
    for name in workspace.keys():
        self.set_full(name, workspace[name].clone())

batch_size(self)

Return the batch size of the variables in the workspace

Source code in salina/workspace.py
def batch_size(self) -> int:
    """Return the batch size shared by all variables of the workspace.

    Returns None for a workspace without variables. Raises AssertionError
    if two variables disagree.
    """
    reference = None
    for variable in self.variables.values():
        current = variable.batch_size()
        if reference is None:
            reference = current
        assert reference == current, "Variables must have the same batch size"
    return reference

cat_batch(workspaces)

Concatenate multiple workspaces over the batch dimension. The workspaces must have the same time dimension.

Source code in salina/workspace.py
def cat_batch(workspaces: list[Workspace]) -> Workspace:
    """Concatenate several workspaces along the batch dimension (dim 1).

    All workspaces must share the same time size. The variables used are
    those of the first workspace. Each one is read in full from every
    workspace and concatenated into a new workspace.
    """
    expected_ts = None
    for ws in workspaces:
        current_ts = ws.time_size()
        if expected_ts is None:
            expected_ts = current_ts
        assert expected_ts == current_ts, "Workspaces must have the same time_size"

    result = Workspace()
    for name in workspaces[0].keys():
        merged = torch.cat([ws[name] for ws in workspaces], dim=1)
        result.set_full(name, merged)
    return result

clear(self)

Clear the content of all the variables of the workspace (the variables themselves remain registered in the workspace).

Source code in salina/workspace.py
def clear(self):
    """Clear the content of every variable in the workspace.

    Each variable's own ``clear()`` is called. The variable names stay
    registered in the workspace.
    """
    for variable in self.variables.values():
        variable.clear()

contiguous(self)

Generates a workspace where all tensors are stored in the Compact format.

Source code in salina/workspace.py
def contiguous(self) -> Workspace:
    """Return a new workspace with every variable stored in the Compact format.

    Each variable is read in full and written into the new workspace with
    `set_full`.
    """
    result = Workspace()
    for name in self.keys():
        full_tensor = self.get_full(name)
        result.set_full(name, full_tensor)
    return result

copy_n_last_steps(self, n, var_names=None)

Copy the n last timesteps of each variable to the n first timesteps.

Source code in salina/workspace.py
def copy_n_last_steps(self, n: int, var_names: Optional[list[str]] = None):
    """Copy the n last timesteps of each selected variable to its n first timesteps.

    Args:
        n: number of timesteps to copy.
        var_names: names of the variables to process. All variables are
            processed when None.

    Raises:
        AssertionError: if the selected variables do not share one time size.
    """
    selected = [
        v for k, v in self.variables.items() if var_names is None or k in var_names
    ]
    if not selected:
        return  # nothing to copy, and no time size to read

    ts = selected[0].time_size()
    for v in selected:
        assert ts == v.time_size(), "Variables must have the same time size"

    # Copy each selected variable exactly once. Calling self.copy_time here
    # would copy *every* variable, once per selected variable, ignoring var_names.
    for v in selected:
        v.copy_time(ts - n, 0, n)

copy_time(self, from_time, to_time, n_steps, var_names=None)

Copy the values of all the variables from timesteps from_time:from_time+n_steps to timesteps to_time:to_time+n_steps. It can be restricted to specific variables using var_names.

Source code in salina/workspace.py
def copy_time(self, from_time: int, to_time: int, n_steps: int, var_names: Optional[list[str]] = None):
    """Copy timesteps ``from_time:from_time+n_steps`` to ``to_time:to_time+n_steps``.

    The copy is applied to every variable, or only to those listed in
    ``var_names``.
    """
    for name, variable in self.variables.items():
        if var_names is not None and name not in var_names:
            continue
        variable.copy_time(from_time, to_time, n_steps)

get(self, var_name, t, batch_dims=None)

Get the variable var_name at time t

Source code in salina/workspace.py
def get(self, var_name: str, t: int, batch_dims: Optional[tuple[int, int]] = None) -> torch.Tensor:
    """Get the value of variable ``var_name`` at timestep ``t``.

    Args:
        var_name: name of the variable.
        t: timestep to read.
        batch_dims: forwarded unchanged to the variable's ``get``.

    Raises:
        AssertionError: if ``var_name`` is not a variable of the workspace.
    """
    assert var_name in self.variables, "Unknown variable '" + var_name + "'"
    return self.variables[var_name].get(t, batch_dims=batch_dims)

get_full(self, var_name, batch_dims=None)

Return the complete tensor for var_name

Source code in salina/workspace.py
def get_full(self, var_name: str, batch_dims: Optional[tuple[int, int]] = None) -> torch.Tensor:
    """Return the complete (TxBx...) tensor of variable ``var_name``.

    Args:
        var_name: name of the variable.
        batch_dims: forwarded unchanged to the variable's ``get_full``.

    Raises:
        AssertionError: if ``var_name`` is not a variable of the workspace.
    """
    assert var_name in self.variables, (
        "[Workspace.get_full] unknown variable '" + var_name + "'"
    )
    return self.variables[var_name].get_full(batch_dims=batch_dims)

get_time_truncated(self, var_name, from_time, to_time, batch_dims=None)

Return workspace[var_name][from_time:to_time]

Source code in salina/workspace.py
def get_time_truncated(self, var_name: str, from_time: int, to_time: int, batch_dims: Optional[tuple[int, int]] = None) -> torch.Tensor:
    """Return ``workspace[var_name][from_time:to_time]``.

    For a SlicedTemporalTensor, only the requested timesteps are stacked.
    Any other variable is read in full and then sliced.

    Args:
        var_name: name of the variable. A missing name raises KeyError.
        from_time: first timestep (inclusive), >= 0.
        to_time: last timestep (exclusive), > from_time.
        batch_dims: forwarded to the variable.
    """
    assert from_time >= 0 and to_time >= 0 and to_time > from_time

    variable = self.variables[var_name]
    if isinstance(variable, SlicedTemporalTensor):
        return variable.get_time_truncated(from_time, to_time, batch_dims)
    return variable.get_full(batch_dims)[from_time:to_time]

get_time_truncated_workspace(self, from_time, to_time)

Return a workspace where all variables are truncated between from_time and to_time

Source code in salina/workspace.py
def get_time_truncated_workspace(self, from_time: int, to_time: int) -> Workspace:
    """Return a new workspace in which every variable is truncated to ``from_time:to_time``."""
    truncated = Workspace()
    for name in self.keys():
        window = self.get_time_truncated(name, from_time, to_time, None)
        truncated.set_full(name, window)
    return truncated

keys(self)

Return an iterator over the variable names

Source code in salina/workspace.py
def keys(self):
    """Return a view of the names of the variables stored in the workspace."""
    variables = self.variables
    return variables.keys()

remove_variable(self, var_name)

Remove a variable from the Workspace

Source code in salina/workspace.py
def remove_variable(self, var_name: str):
    """Delete variable ``var_name`` from the workspace. Raises KeyError if it is absent."""
    self.variables.pop(var_name)

select_batch(self, batch_indexes)

Given a tensor of indexes, it returns a new workspace with the selected elements (over the batch dimension)

Source code in salina/workspace.py
def select_batch(self, batch_indexes: torch.LongTensor) -> Workspace:
    """Return a new workspace that keeps only the batch entries at ``batch_indexes``.

    All variables must share the same batch size; this is checked by assertion.
    """
    reference = None
    for variable in self.variables.values():
        current = variable.batch_size()
        if reference is None:
            reference = current
        assert reference == current, "Variables must have the same batch size"

    result = Workspace()
    for name, variable in self.variables.items():
        result.variables[name] = variable.select_batch(batch_indexes)
    return result

select_batch_n(self, n)

Return a new Workspace of batch_size==n by randomly sampling over the batch dimension

Source code in salina/workspace.py
def select_batch_n(self, n):
    """Return a new workspace of batch size ``n`` built from random batch indexes.

    The indexes are drawn with ``torch.randint`` over ``[0, batch_size())``,
    so the same batch element may be selected several times.
    """
    high = self.batch_size()
    sampled = torch.randint(low=0, high=high, size=(n,))
    return self.select_batch(sampled)

select_subtime(self, t, window_size)

t is a tensor of size batch_size that provides one time index for each element of the workspace. The function then returns a new workspace by aggregating window_size timesteps starting from index t. This method allows sampling multiple windows in the Workspace. Note that the function may be quite slow.

Source code in salina/workspace.py
def select_subtime(self, t: torch.LongTensor, window_size: int) -> Workspace:
    """Extract one time window per batch element into a new workspace.

    ``t`` is a tensor of size ``batch_size`` holding one start index per
    batch element. For each variable, ``window_size`` timesteps starting at
    that index are gathered with ``take_per_row_strided``.
    This allows sampling multiple windows in the workspace, but it may be
    quite slow.
    """
    full_tensors = {}
    for name, variable in self.variables.items():
        full_tensors[name] = variable.get_full(batch_dims=None)

    result = Workspace()
    for name, tensor in full_tensors.items():
        windows = take_per_row_strided(tensor, t, num_elem=window_size)
        result.set_full(name, windows, batch_dims=None)
    return result

set(self, var_name, t, v, batch_dims=None)

Set the variable var_name at time t

Source code in salina/workspace.py
def set(self, var_name: str, t: int, v: torch.Tensor, batch_dims: Optional[tuple[int, int]] = None):
    """Set the value of variable ``var_name`` at timestep ``t``.

    A missing variable is created as an empty SlicedTemporalTensor; this is
    not allowed in a shared workspace. An existing CompactTemporalTensor,
    which has a fixed size, is first converted with ``to_sliced()``.

    Args:
        var_name: name of the variable.
        t: timestep to write.
        v: value to store at timestep t.
        batch_dims: forwarded unchanged to the variable's ``set``.
    """
    if var_name not in self.variables:
        assert not self.is_shared, "Cannot add new variable into a shared workspace"
        self.variables[var_name] = SlicedTemporalTensor()
    elif isinstance(self.variables[var_name], CompactTemporalTensor):
        self.variables[var_name] = self.variables[var_name].to_sliced()

    self.variables[var_name].set(t, v, batch_dims=batch_dims)

set_full(self, var_name, value, batch_dims=None)

Set variable var_name with a complete tensor (TxBx...)

Source code in salina/workspace.py
def set_full(self, var_name: str, value: torch.Tensor, batch_dims: Optional[tuple[int, int]] = None):
    """Set variable ``var_name`` from a complete TxBx... tensor.

    A missing variable is created as a CompactTemporalTensor; this is not
    allowed in a shared workspace. An existing variable keeps its current
    storage type and receives the value through its own ``set_full``.

    Args:
        var_name: name of the variable.
        value: the full tensor.
        batch_dims: forwarded unchanged to the variable's ``set_full``.
    """
    if var_name not in self.variables:
        assert not self.is_shared, "Cannot add new variable into a shared workspace"
        self.variables[var_name] = CompactTemporalTensor()
    self.variables[var_name].set_full(value, batch_dims=batch_dims)

subtime(self, from_t, to_t)

Return a workspace restricted to a subset of the time dimension

Source code in salina/workspace.py
def subtime(self, from_t: int, to_t: int) -> Workspace:
    """Return a new workspace restricted to timesteps ``from_t:to_t``.

    Raises AssertionError unless all variables share the same time size.
    """
    assert (
        self._all_variables_same_time_size()
    ), "All variables must have the same time size"
    result = Workspace()
    for name in self.variables:
        result.variables[name] = self.variables[name].subtime(from_t, to_t)
    return result

time_size(self)

Return the time size of the variables in the workspace

Source code in salina/workspace.py
def time_size(self) -> int:
    """Return the time size shared by all variables of the workspace.

    Returns None for a workspace without variables. Raises AssertionError
    if two variables disagree.
    """
    reference = None
    for variable in self.variables.values():
        current = variable.time_size()
        if reference is None:
            reference = current
        assert reference == current, "Variables must have the same time size"
    return reference

to(self, device)

Return a workspace where all tensors are on a particular device

Source code in salina/workspace.py
def to(self, device: torch.device) -> Workspace:
    """Return a new workspace holding each variable converted with its own ``to(device)``."""
    moved = Workspace()
    for name in self.variables:
        moved.variables[name] = self.variables[name].to(device)
    return moved

zero_grad(self)

Remove any gradient information

Source code in salina/workspace.py
def zero_grad(self):
    """Remove any gradient information by calling ``zero_grad`` on every variable."""
    for variable in self.variables.values():
        variable.zero_grad()