salina.workspace
salina.workspace
CompactSharedTensor
It corresponds to a tensor in shared memory and is used when building a workspace shared by multiple processes.
All the methods behave like those of SlicedTemporalTensor.
CompactTemporalTensor
A CompactTemporalTensor is a tensor of size TxBx... It behaves like the SlicedTemporalTensor but has a fixed size that cannot change. It is faster than the SlicedTemporalTensor.
See SlicedTemporalTensor
to_sliced(self)
Transform the tensor into a SlicedTemporalTensor
Source code in salina/workspace.py
def to_sliced(self) -> SlicedTemporalTensor:
    """Convert this compact tensor into a `SlicedTemporalTensor`.

    Every slice ``self.tensor[t]`` along the first (time) dimension becomes
    its own timestep in the returned sliced tensor.

    Returns:
        SlicedTemporalTensor: a new tensor holding one slice per timestep.
    """
    sliced = SlicedTemporalTensor()
    n_steps = self.tensor.size()[0]
    for step in range(n_steps):
        sliced.set(step, self.tensor[step], None)
    return sliced
SlicedTemporalTensor
A SlicedTemporalTensor represents a tensor of size TxBx... by using a list of tensors of size Bx... The interest is that this tensor automatically adapts its timestep dimension and does not need to have a predefined size.
__init__(self)
special
Initialize an empty tensor
Source code in salina/workspace.py
def __init__(self):
""" Initialize an empty tensor
"""
self.tensors: list[torch.Tensor] = []
self.size: toch.Size = None
self.device: torch.device = None
self.dtype: torch.dtype = None
batch_size(self)
Return the size of the batch dimension
Source code in salina/workspace.py
def batch_size(self):
    """Return the size of the batch dimension.

    The value is read from dimension 0 of the first stored timestep, so at
    least one timestep must have been set.
    """
    first_slice = self.tensors[0]
    return first_slice.shape[0]
clear(self)
Clear the tensor
Source code in salina/workspace.py
def clear(self):
    """Drop every stored timestep and forget the recorded size, device and dtype."""
    self.tensors = []
    self.size, self.device, self.dtype = None, None, None
copy_time(self, from_time, to_time, n_steps)
Copy temporal slices of the tensor from from_time:from_time+n_steps to to_time:to_time+n_steps
Source code in salina/workspace.py
def copy_time(self, from_time:int, to_time:int, n_steps:int):
    """Copy timesteps ``from_time:from_time+n_steps`` onto ``to_time:to_time+n_steps``.

    Slices are copied one by one, in increasing time order, through ``get``
    and ``set`` (both called with ``batch_dims=None``).
    """
    for offset in range(n_steps):
        src, dst = from_time + offset, to_time + offset
        self.set(dst, self.get(src, batch_dims=None), batch_dims=None)
get(self, t, batch_dims)
Get the value of the tensor at time t
Source code in salina/workspace.py
def get(self, t:int, batch_dims:Optional[tuple[int, int]]):
    """Return the slice (Bx...) stored at timestep ``t``.

    Args:
        t: time index; must be lower than the number of stored timesteps.
        batch_dims: must be None; batch sub-ranges are not supported by
            SlicedTemporalTensor.

    Returns:
        torch.Tensor: the stored tensor itself (not a copy).
    """
    assert (
        batch_dims is None
    ), "Unable to use batch dimensions with SlicedTemporalTensor"
    assert t < len(self.tensors), "Temporal index out of bouds"
    return self.tensors[t]
get_full(self, batch_dims)
Returns the complete tensor of size TxBx...
Source code in salina/workspace.py
def get_full(self, batch_dims):
    """Stack every stored timestep into a single TxBx... tensor.

    Args:
        batch_dims: must be None; batch sub-ranges are not supported by
            SlicedTemporalTensor.
    """
    assert batch_dims is None, "Unable to use batch dimensions with SlicedTemporalTensor"
    # Stacking along a new leading dim is the same as cat-ing unsqueezed slices.
    return torch.stack(self.tensors, dim=0)
get_time_truncated(self, from_time, to_time, batch_dims)
Returns tensor[from_time:to_time]
Source code in salina/workspace.py
def get_time_truncated(self, from_time:int, to_time:int, batch_dims:Optional[tuple[int, int]]):
    """Return ``tensor[from_time:to_time]`` as a single tensor.

    Args:
        from_time: first timestep (inclusive), must be >= 0.
        to_time: last timestep (exclusive), must be > from_time. Values past
            the number of stored timesteps are clipped.
        batch_dims: must be None (not supported by SlicedTemporalTensor).

    Returns:
        torch.Tensor: the selected timesteps stacked along a new leading
        (time) dimension.
    """
    assert from_time >= 0 and to_time >= 0 and to_time > from_time
    assert batch_dims is None
    end = min(len(self.tensors), to_time)
    return torch.stack(self.tensors[from_time:end], dim=0)
select_batch(self, batch_indexes)
Return the tensor where the batch dimension has been selected by the index
Source code in salina/workspace.py
def select_batch(self, batch_indexes:torch.LongTensor):
    """Return a new SlicedTemporalTensor whose slices are ``slice[batch_indexes]``.

    The indexing is applied independently at every timestep, so the time
    dimension is preserved.
    """
    result = SlicedTemporalTensor()
    for step, tensor in enumerate(self.tensors):
        picked = tensor[batch_indexes]
        result.set(step, picked, None)
    return result
set(self, t, value, batch_dims)
Set a value (dim Bx...) at time t
Source code in salina/workspace.py
def set(self, t:int, value:torch.Tensor, batch_dims:Optional[tuple[int, int]]):
    """Set the value (of size Bx...) at timestep ``t``.

    The first value ever stored fixes the size, device and dtype of the
    tensor, and later values must match them. If ``t`` is past the current
    end, the missing timesteps are first filled with zero tensors.

    Args:
        t: time index.
        value: tensor of size Bx... to store (stored as-is, not copied).
        batch_dims: must be None (not supported by SlicedTemporalTensor).
    """
    assert (
        batch_dims is None
    ), "Unable to use batch dimensions with SlicedTemporalTensor"
    if self.size is None:
        self.size = value.size()
        self.device = value.device
        self.dtype = value.dtype
    assert self.size == value.size(), "Incompatible size"
    assert self.device == value.device, "Incompatible device"
    assert self.dtype == value.dtype, "Incompatible type"
    # Grow the list with zero placeholders until index t exists.
    while len(self.tensors) <= t:
        self.tensors.append(
            torch.zeros(*self.size, device=self.device, dtype=self.dtype)
        )
    self.tensors[t] = value
set_full(self, value, batch_dims)
Set the tensor given a TxBx... tensor. The input tensor is cut along its first (time) dimension into slices that are stored in a list of tensors
Source code in salina/workspace.py
def set_full(self, value:torch.Tensor, batch_dims:Optional[tuple[int, int]]):
    """Set the tensor from a complete TxBx... tensor.

    ``value`` is cut along its first (time) dimension and each slice
    ``value[t]`` is stored through ``set``. Previously stored timesteps with
    index >= T are left unchanged.

    Args:
        value: tensor of size TxBx...
        batch_dims: must be None (not supported by SlicedTemporalTensor).
    """
    assert (
        batch_dims is None
    ), "Unable to use batch dimensions with SlicedTemporalTensor"
    for t in range(value.size()[0]):
        self.set(t, value[t], batch_dims=batch_dims)
subtime(self, from_t, to_t)
Return tensor[from_t:to_t]
Source code in salina/workspace.py
def subtime(self, from_t:int, to_t:int):
    """Return timesteps ``from_t:to_t`` as a new CompactTemporalTensor.

    The bounds follow Python slice semantics on the list of stored
    timesteps.
    """
    window = self.tensors[from_t:to_t]
    return CompactTemporalTensor(torch.stack(window, dim=0))
time_size(self)
Return the size of the time dimension
Source code in salina/workspace.py
def time_size(self):
    """Return the number of stored timesteps (size of the time dimension)."""
    n_steps = len(self.tensors)
    return n_steps
to(self, device)
Move the tensor to a specific device
Source code in salina/workspace.py
def to(self, device:torch.device):
    """Return a new SlicedTemporalTensor with every timestep moved to ``device``.

    ``self`` is not modified.
    """
    moved = SlicedTemporalTensor()
    for t, tensor in enumerate(self.tensors):
        # `set` has no default for `batch_dims`, so it must be passed explicitly;
        # omitting it raised TypeError on every call.
        moved.set(t, tensor.to(device), None)
    return moved
zero_grad(self)
Clear any gradient information in the tensor
Source code in salina/workspace.py
def zero_grad(self):
    """Replace every stored timestep by its detached version (drops autograd history)."""
    detached = []
    for tensor in self.tensors:
        detached.append(tensor.detach())
    self.tensors = detached
Workspace
Workspace is the most important class in SaLinA. It corresponds to a collection of tensors (`SlicedTemporalTensor`, `CompactTemporalTensor` or `CompactSharedTensor`).
In the majority of cases, we consider that all the tensors have the same time and batch sizes (but it is not mandatory for most of the functions)
__getitem__(self, key)
special
If key is a string, it returns a torch.Tensor; if key is a list of strings, it returns a tuple of torch.Tensor.
Source code in salina/workspace.py
def __getitem__(self, key):
""" if key is a string, then it returns a torch.Tensor
if key is a list of string, it returns a tuple of torch.Tensor
"""
if isinstance(key, str):
return self.get_full(key, None)
else:
return (self.get_full(k, None) for k in key)
__init__(self, workspace=None)
special
Create an empty workspace
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| workspace | Workspace | If specified, it creates a copy of the workspace (where tensors are cloned as CompactTemporalTensors) | None |
Source code in salina/workspace.py
def __init__(self, workspace:Optional[Workspace]=None):
""" Create an empty workspace
Args:
workspace (Workspace, optional): If specified, it creates a copy of the workspace (where tensors are cloned as CompactTemporalTensors)
"""
self.variables = {}
self.is_shared = False
if not workspace is None:
for k in workspace.keys():
self.set_full(k, workspace[k].clone())
batch_size(self)
Return the batch size of the variables in the workspace
Source code in salina/workspace.py
def batch_size(self) -> int:
    """Return the batch size shared by all the variables of the workspace.

    Returns None when the workspace has no variable.

    Raises:
        AssertionError: if two variables report different batch sizes.
    """
    reference = None
    for var in self.variables.values():
        current = var.batch_size()
        if reference is None:
            reference = current
        assert reference == current, "Variables must have the same batch size"
    return reference
cat_batch(workspaces)
Concatenate multiple workspaces over the batch dimension. The workspaces must have the same time dimension.
Source code in salina/workspace.py
def cat_batch(workspaces:list[Workspace]) -> Workspace:
    """Concatenate several workspaces along the batch dimension.

    All workspaces must report the same ``time_size()``. The variable names
    are taken from the first workspace. For each of them, the complete
    TxBx... tensors of all workspaces are concatenated along dim 1 and stored
    in a new workspace.

    Args:
        workspaces: the workspaces to concatenate (must be non-empty).

    Returns:
        Workspace: a new workspace holding the concatenated variables.
    """
    expected_ts = None
    for ws in workspaces:
        if expected_ts is None:
            expected_ts = ws.time_size()
        assert expected_ts == ws.time_size(), "Workspaces must have the same time_size"
    merged = Workspace()
    for name in workspaces[0].keys():
        parts = [ws[name] for ws in workspaces]
        merged.set_full(name, torch.cat(parts, dim=1))
    return merged
clear(self)
Clear the content of every variable in the workspace (the variable names are kept)
Source code in salina/workspace.py
def clear(self):
    """Clear the content of every variable of the workspace.

    Each variable's own ``clear`` method is called. The variable names stay
    registered in ``self.variables``.
    """
    for var in self.variables.values():
        var.clear()
contiguous(self)
Generates a workspace where all tensors are stored in the Compact format.
Source code in salina/workspace.py
def contiguous(self) -> Workspace:
    """Return a new workspace where every variable is stored in compact form.

    Each variable's complete tensor is read with ``get_full`` and written
    into a fresh workspace with ``set_full``.
    """
    result = Workspace()
    for name in self.keys():
        full_tensor = self.get_full(name)
        result.set_full(name, full_tensor)
    return result
copy_n_last_steps(self, n, var_names=None)
Copy the n last timesteps of each variable to its n first timesteps.
Source code in salina/workspace.py
def copy_n_last_steps(self, n:int, var_names:Optional[list[str]]=None):
    """Copy the ``n`` last timesteps of the selected variables to their ``n`` first timesteps.

    Args:
        n: number of timesteps to copy.
        var_names: names of the variables to process; all variables when None.

    Raises:
        AssertionError: if the selected variables do not share the same time size.
    """
    selected = [
        v for k, v in self.variables.items()
        if var_names is None or k in var_names
    ]
    _ts = None
    for v in selected:
        if _ts is None:
            _ts = v.time_size()
        assert _ts == v.time_size(), "Variables must have the same time size"
    for v in selected:
        # Copy on each selected variable exactly once. Calling self.copy_time
        # here would ignore var_names and re-copy every variable once per
        # selected variable, which corrupts data when the ranges overlap.
        v.copy_time(_ts - n, 0, n)
copy_time(self, from_time, to_time, n_steps, var_names=None)
Copy all the variables values from time from_time to from_time+n_steps to to_time to to_time+n_steps
It can be restricted to specific variables using var_names
Source code in salina/workspace.py
def copy_time(self, from_time:int, to_time:int, n_steps:int, var_names:Optional[list[str]]=None):
    """Copy timesteps ``from_time:from_time+n_steps`` onto ``to_time:to_time+n_steps``.

    The copy is delegated to each variable's ``copy_time``. When
    ``var_names`` is given, only the variables whose names it contains are
    processed.
    """
    for name, var in self.variables.items():
        if var_names is not None and name not in var_names:
            continue
        var.copy_time(from_time, to_time, n_steps)
get(self, var_name, t, batch_dims=None)
Get the variable var_name at time t
Source code in salina/workspace.py
def get(self, var_name:str, t:int, batch_dims:Optional[tuple[int, int]]=None) -> torch.Tensor:
    """Return the value of variable ``var_name`` at timestep ``t``.

    Args:
        var_name: name of the variable; must exist in the workspace.
        t: time index.
        batch_dims: forwarded unchanged to the variable's ``get``
            (SlicedTemporalTensor only accepts None).
    """
    assert var_name in self.variables, "Unknoanw variable '" + var_name + "'"
    return self.variables[var_name].get(t, batch_dims=batch_dims)
get_full(self, var_name, batch_dims=None)
Return the complete tensor for var_name
Source code in salina/workspace.py
def get_full(self, var_name:str, batch_dims:Optional[tuple[int, int]]=None) -> torch.Tensor:
    """Return the complete (TxBx...) tensor of variable ``var_name``.

    Args:
        var_name: name of the variable; must exist in the workspace.
        batch_dims: forwarded unchanged to the variable's ``get_full``.
    """
    assert var_name in self.variables, (
        "[Workspace.get_full] unnknown variable '" + var_name + "'"
    )
    return self.variables[var_name].get_full(batch_dims=batch_dims)
get_time_truncated(self, var_name, from_time, to_time, batch_dims=None)
Return workspace[var_name][from_time:to_time]
Source code in salina/workspace.py
def get_time_truncated(self, var_name:str, from_time:int, to_time:int, batch_dims:Optional[tuple[int, int]]=None) -> torch.Tensor:
    """Return ``workspace[var_name][from_time:to_time]``.

    SlicedTemporalTensor variables use their own ``get_time_truncated``.
    For any other variable type, the complete tensor is materialized with
    ``get_full`` and then sliced.

    Args:
        var_name: name of the variable.
        from_time: first timestep (inclusive), must be >= 0.
        to_time: last timestep (exclusive), must be > from_time.
        batch_dims: forwarded to the variable.
    """
    assert from_time >= 0 and to_time >= 0 and to_time > from_time
    v = self.variables[var_name]
    if isinstance(v, SlicedTemporalTensor):
        return v.get_time_truncated(from_time, to_time, batch_dims)
    else:
        return v.get_full(batch_dims)[from_time:to_time]
get_time_truncated_workspace(self, from_time, to_time)
Return a workspace where all variables are truncated between from_time and to_time
Source code in salina/workspace.py
def get_time_truncated_workspace(self, from_time:int, to_time:int) -> Workspace:
    """Return a new workspace with every variable restricted to ``from_time:to_time``.

    Each truncated tensor comes from ``get_time_truncated`` (with
    ``batch_dims=None``) and is stored with ``set_full``.
    """
    result = Workspace()
    for name in self.keys():
        truncated = self.get_time_truncated(name, from_time, to_time, None)
        result.set_full(name, truncated)
    return result
keys(self)
Return an iterator over the variable names
Source code in salina/workspace.py
def keys(self):
    """Return a view over the names of the variables stored in the workspace."""
    names = self.variables.keys()
    return names
remove_variable(self, var_name)
Remove a variable from the Workspace
Source code in salina/workspace.py
def remove_variable(self, var_name:str):
    """Delete the variable ``var_name`` from the workspace.

    Raises:
        KeyError: if no variable has that name.
    """
    self.variables.pop(var_name)
select_batch(self, batch_indexes)
Given a tensor of indexes, it returns a new workspace with the selected elements (over the batch dimension)
Source code in salina/workspace.py
def select_batch(self, batch_indexes:torch.LongTensor) -> Workspace:
    """Return a new workspace whose variables are indexed by ``batch_indexes`` along the batch dimension.

    Raises:
        AssertionError: if the variables do not share the same batch size.
    """
    reference = None
    for var in self.variables.values():
        current = var.batch_size()
        if reference is None:
            reference = current
        assert reference == current, "Variables must have the same batch size"
    selected = Workspace()
    for name, var in self.variables.items():
        selected.variables[name] = var.select_batch(batch_indexes)
    return selected
select_batch_n(self, n)
Return a new Workspace of batch_size==n by randomly sampling over the batch dimension
Source code in salina/workspace.py
def select_batch_n(self, n):
    """Return a workspace of batch size ``n`` built by sampling batch indexes uniformly, with replacement."""
    upper = self.batch_size()
    sampled = torch.randint(0, upper, (n,))
    return self.select_batch(sampled)
select_subtime(self, t, window_size)
t is a tensor of size batch_size that provides one time index for each element of the workspace.
Then the function returns a new workspace by aggregating window_size timesteps starting from index t
This method allows sampling multiple windows in the Workspace.
Note that the function may be quite slow.
Source code in salina/workspace.py
def select_subtime(self, t: torch.LongTensor, window_size:int) -> Workspace:
    """
    Build a workspace of per-element time windows.

    `t` is a tensor of size `batch_size` giving one starting time index per
    batch element. For every variable, `window_size` timesteps starting at
    that index are gathered (via `take_per_row_strided`) into a new workspace.
    This allows sampling multiple windows from the Workspace, but it may be
    slow: every variable is first materialized with `get_full`.
    """
    full = {}
    for name, var in self.variables.items():
        full[name] = var.get_full(batch_dims=None)
    result = Workspace()
    for name, tensor in full.items():
        windows = take_per_row_strided(tensor, t, num_elem=window_size)
        result.set_full(name, windows, batch_dims=None)
    return result
set(self, var_name, t, v, batch_dims=None)
Set the variable var_name at time t
Source code in salina/workspace.py
def set(self, var_name:str, t:int, v:torch.Tensor, batch_dims:Optional[tuple[int, int]]=None):
    """Set the value of variable ``var_name`` at timestep ``t``.

    An unknown variable is created as a SlicedTemporalTensor; this is refused
    on a shared workspace. A variable currently stored as a
    CompactTemporalTensor is first converted with ``to_sliced()``.

    Args:
        var_name: name of the variable.
        t: time index.
        v: value for that timestep.
        batch_dims: forwarded to the variable's ``set``.
    """
    if not var_name in self.variables:
        assert not self.is_shared, "Cannot add new variable into a shared workspace"
        self.variables[var_name] = SlicedTemporalTensor()
    elif isinstance(self.variables[var_name], CompactTemporalTensor):
        self.variables[var_name] = self.variables[var_name].to_sliced()
    self.variables[var_name].set(t, v, batch_dims=batch_dims)
set_full(self, var_name, value, batch_dims=None)
Set variable var_name with a complete tensor (TxBx...)
Source code in salina/workspace.py
def set_full(self, var_name:str, value:torch.Tensor, batch_dims:Optional[tuple[int, int]]=None):
    """Set variable ``var_name`` from a complete TxBx... tensor.

    An unknown variable is created as a CompactTemporalTensor; this is
    refused on a shared workspace. An existing variable keeps its current
    storage class and receives the value through its own ``set_full``.

    Args:
        var_name: name of the variable.
        value: complete tensor (TxBx...).
        batch_dims: forwarded to the variable's ``set_full``.
    """
    if not var_name in self.variables:
        assert not self.is_shared, "Cannot add new variable into a shared workspace"
        self.variables[var_name] = CompactTemporalTensor()
    self.variables[var_name].set_full(value, batch_dims=batch_dims)
subtime(self, from_t, to_t)
Return a workspace restricted to a subset of the time dimension
Source code in salina/workspace.py
def subtime(self, from_t:int, to_t:int) -> Workspace:
    """
    Return a new workspace restricted to timesteps `from_t:to_t`.

    Every variable is truncated with its own `subtime`. All variables must
    share the same time size.
    """
    same_size = self._all_variables_same_time_size()
    assert same_size, "All variables must have the same time size"
    result = Workspace()
    for name, var in self.variables.items():
        result.variables[name] = var.subtime(from_t, to_t)
    return result
time_size(self)
Return the time size of the variables in the workspace
Source code in salina/workspace.py
def time_size(self) -> int:
    """Return the time size shared by all the variables of the workspace.

    Returns None when the workspace has no variable.

    Raises:
        AssertionError: if two variables report different time sizes.
    """
    reference = None
    for var in self.variables.values():
        current = var.time_size()
        if reference is None:
            reference = current
        assert reference == current, "Variables must have the same time size"
    return reference
to(self, device)
Return a workspace where all tensors are on a particular device
Source code in salina/workspace.py
def to(self, device:torch.device) -> Workspace:
    """Return a new workspace in which every variable has been moved to ``device``.

    Each variable is converted by its own ``to`` method; ``self`` is not modified.
    """
    moved = Workspace()
    for name, var in self.variables.items():
        moved.variables[name] = var.to(device)
    return moved
zero_grad(self)
Remove any gradient information
Source code in salina/workspace.py
def zero_grad(self):
    """Remove any gradient information by calling ``zero_grad`` on every variable."""
    for var in self.variables.values():
        var.zero_grad()