salina.agents.dataloader
salina.agents.dataloader.ShuffledDatasetAgent (Agent)
An agent that reads a dataset in shuffled order, indefinitely.
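The following is a minimal, framework-free sketch of the sampling idea described above. It is not salina's implementation: the function `shuffled_batches` and its `seed` parameter are hypothetical, and plain Python lists stand in for tensors and the workspace. It assumes indices are drawn uniformly with replacement. Reshuffling a full permutation after each pass would be an equally valid reading of "shuffled order".

```python
import random
from typing import Iterator, Sequence, Tuple


def shuffled_batches(
    dataset: Sequence[tuple],
    batch_size: int,
    output_names: Tuple[str, ...] = ("x", "y"),
    seed: int = 0,
) -> Iterator[dict]:
    """Yield an endless stream of randomly sampled batches (hypothetical helper).

    Each batch maps every name in output_names to a list holding the
    matching field of batch_size datapoints drawn from dataset.
    """
    rng = random.Random(seed)  # local RNG so results are reproducible
    while True:  # never stops on its own; the caller decides when to quit
        # Draw batch_size indices uniformly, with replacement (an assumption)
        idxs = [rng.randrange(len(dataset)) for _ in range(batch_size)]
        samples = [dataset[i] for i in idxs]
        # Transpose the list of (x, y, ...) tuples into one list per name
        yield {name: [s[k] for s in samples] for k, name in enumerate(output_names)}


# Usage: ten (x, y) pairs where y is the parity of x
data = [(float(i), i % 2) for i in range(10)]
stream = shuffled_batches(data, batch_size=4)
batch = next(stream)
print(len(batch["x"]), len(batch["y"]))  # 4 4
```

Because the generator never raises `StopIteration`, a consumer can keep requesting batches for as long as training runs, which matches the "infinite" behaviour described for this agent.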
__init__(self, dataset, batch_size, output_names=('x', 'y'))
Create the agent
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | torch.utils.data.Dataset | The dataset to read from | required |
| batch_size | int | The number of datapoints to write at each call | required |
| output_names | tuple | The names of the variables to write. Defaults to ("x", "y"). | ('x', 'y') |
Source code in salina/agents/dataloader.py
```python
def __init__(
    self,
    dataset,
    batch_size,
    output_names=("x", "y"),
):
    """Create the agent

    Args:
        dataset ([torch.utils.data.Dataset]): the Dataset
        batch_size ([int]): The number of datapoints to write at each call
        output_names (tuple, optional): The names of the variables. Defaults to ("x", "y").
    """
    super().__init__()
    self.output_names = output_names
    self.dataset = dataset
    self.batch_size = batch_size
    # Dummy scalar parameter; presumably lets the agent know which device it has been moved to
    self.ghost_params = torch.nn.Parameter(torch.randn(()))
```
salina.agents.dataloader.DataLoaderAgent (Agent)
An agent based on a DataLoader that reads a single dataset once. Usage: call `agent.forward()`, then check whether `agent.finished()` is True. If it is, no data has been written to the workspace on that call, because the dataset has been fully read.
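The call-then-check protocol above can be illustrated with a small, self-contained mock. This is a hypothetical sketch, not salina code: `OnePassReader` and `read_all` are invented names, a plain dict stands in for the workspace, and any iterable of tuples stands in for the PyTorch DataLoader. The `iter` and `_finished` attributes mirror those created in the constructor below; how the real `forward()` uses them is an assumption.

```python
from typing import Iterable, List, Tuple


class OnePassReader:
    """Hypothetical stand-in for the forward()/finished() protocol (not salina code)."""

    def __init__(self, loader: Iterable, output_names: Tuple[str, ...] = ("x", "y")):
        self.loader = loader
        self.iter = iter(loader)  # one iterator, consumed batch by batch
        self.output_names = output_names
        self._finished = False
        self.workspace: dict = {}  # plain dict standing in for the workspace

    def forward(self) -> None:
        try:
            values = next(self.iter)
        except StopIteration:
            # Dataset exhausted: write nothing and raise the flag
            self._finished = True
            return
        for name, value in zip(self.output_names, values):
            self.workspace[name] = value

    def finished(self) -> bool:
        return self._finished


def read_all(reader: OnePassReader) -> List[dict]:
    """Call forward() until finished() reports True; collect each batch."""
    batches = []
    while True:
        reader.forward()
        if reader.finished():
            break  # nothing new was written on this call
        batches.append(dict(reader.workspace))
    return batches


# Usage: a "loader" yielding two (x, y) batches
batches = read_all(OnePassReader([([1, 2], [0, 1]), ([3], [1])]))
print(len(batches))  # 2
```

Checking `finished()` immediately after each `forward()` is what prevents the caller from processing stale workspace contents left over from the previous batch.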
__init__(self, dataloader, output_names=('x', 'y'))
Create the agent based on a dataloader
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dataloader | torch.utils.data.DataLoader | The underlying PyTorch DataLoader object | required |
| output_names | tuple | Names of the variables to write in the workspace. Defaults to ("x", "y"). | ('x', 'y') |
Source code in salina/agents/dataloader.py
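```python
def __init__(self, dataloader, output_names=("x", "y")):
    """Create the agent based on a dataloader

    Args:
        dataloader ([DataLoader]): The underlying PyTorch DataLoader object
        output_names (tuple, optional): Names of the variables to write in the workspace. Defaults to ("x", "y").
    """
    super().__init__()
    self.dataloader = dataloader
    # Iterator over the dataloader, advanced one batch at a time
    self.iter = iter(self.dataloader)
    self.output_names = output_names
    # Becomes True once the dataloader has been fully read (see finished())
    self._finished = False
    # Dummy scalar parameter; presumably lets the agent know which device it has been moved to
    self.ghost_params = torch.nn.Parameter(torch.randn(()))
```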