
salina.agents.dataloader

salina.agents.dataloader.ShuffledDatasetAgent (Agent)

An agent that reads a dataset in shuffled order, indefinitely: it never runs out of data.

__init__(self, dataset, batch_size, output_names=('x', 'y')) special

Create the agent

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | torch.utils.data.Dataset | The dataset to read from | required |
| batch_size | int | The number of datapoints to write at each call | required |
| output_names | tuple | The names of the variables to write. Defaults to ("x", "y"). | ('x', 'y') |

Source code in salina/agents/dataloader.py
def __init__(
    self,
    dataset,
    batch_size,
    output_names=("x", "y"),
):
    """Create the agent

    Args:
        dataset ([torch.utils.data.Dataset]): the Dataset
        batch_size ([int]): The number of datapoints to write at each call
        output_names (tuple, optional): The name of the variables. Defaults to ("x", "y").
    """
    super().__init__()
    self.output_names = output_names
    self.dataset = dataset
    self.batch_size = batch_size
    self.ghost_params = torch.nn.Parameter(torch.randn(()))
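The source listing above shows only the constructor, so the exact sampling strategy of `ShuffledDatasetAgent` is not documented here. As a rough illustration of "shuffled order, indefinitely", here is a minimal plain-Python sketch that assumes one reasonable interpretation: a fresh random permutation of the indices for each pass, with each dataset item being a tuple whose fields line up with `output_names`. The function name `shuffled_batches`, the `seed` parameter, and the dict-of-lists batch format are hypothetical and are not salina's API.

```python
import random
from typing import Dict, Iterator, List, Sequence, Tuple


def shuffled_batches(
    dataset: Sequence[Tuple],
    batch_size: int,
    output_names: Tuple[str, ...] = ("x", "y"),
    seed: int = 0,
) -> Iterator[Dict[str, List]]:
    """Hypothetical sketch: yield batches forever, reshuffling after each full pass."""
    if len(dataset) == 0:
        raise ValueError("dataset must not be empty")
    rng = random.Random(seed)
    order: List[int] = []
    while True:
        batch: Dict[str, List] = {name: [] for name in output_names}
        for _ in range(batch_size):
            if not order:
                # Start a new pass: a fresh random permutation of all indices.
                order = list(range(len(dataset)))
                rng.shuffle(order)
            item = dataset[order.pop()]
            # Assumption: each item is a tuple whose fields match output_names.
            for name, value in zip(output_names, item):
                batch[name].append(value)
        yield batch


# Usage: 5 datapoints, batches of 3, so the second batch crosses a pass boundary.
data = [(i, i % 2) for i in range(5)]
gen = shuffled_batches(data, batch_size=3)
b1, b2 = next(gen), next(gen)
```

Reshuffling per pass guarantees every datapoint is seen once per pass; sampling indices independently with replacement would be an equally valid reading of the description, at the cost of that guarantee.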

salina.agents.dataloader.DataLoaderAgent (Agent)

An agent based on a DataLoader that reads a single dataset once. Usage: call agent.forward(), then check whether agent.finished() is True. If it is, no data has been written to the workspace, because the dataset has been fully read.
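The forward-then-check protocol described above can be sketched with a toy class. This is a hypothetical stand-in, not salina's implementation: the name `OnePassReader` is invented, and it writes into a plain `dict` instead of a salina workspace. It mirrors the constructor shown below (an iterator over the loader plus a `_finished` flag) and assumes that reaching the end of the iterator is what sets the flag.

```python
from typing import Dict, Iterable, Iterator, Tuple


class OnePassReader:
    """Hypothetical stand-in for the forward()/finished() protocol."""

    def __init__(self, batches: Iterable[Tuple], output_names=("x", "y")):
        self.iter: Iterator[Tuple] = iter(batches)
        self.output_names = output_names
        self._finished = False
        # Assumption: a plain dict plays the role of the workspace.
        self.workspace: Dict[str, object] = {}

    def forward(self) -> None:
        try:
            batch = next(self.iter)
        except StopIteration:
            # Dataset exhausted: write nothing and flag completion.
            self._finished = True
            return
        for name, value in zip(self.output_names, batch):
            self.workspace[name] = value

    def finished(self) -> bool:
        return self._finished


# Usage: each element stands for one (x, y) batch from a DataLoader.
reader = OnePassReader([([1, 2], [0, 1]), ([3], [1])])
seen = []
while True:
    reader.forward()
    if reader.finished():
        break  # nothing new was written on this call
    seen.append(reader.workspace["x"])
```

Checking `finished()` after every `forward()` matters: on the final call the workspace still holds the previous batch, so reading it without the check would process that batch twice.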

__init__(self, dataloader, output_names=('x', 'y')) special

Create the agent based on a dataloader

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataloader | torch.utils.data.DataLoader | The underlying PyTorch DataLoader object | required |
| output_names | tuple | Names of the variables to write to the workspace. Defaults to ("x", "y"). | ('x', 'y') |

Source code in salina/agents/dataloader.py
def __init__(self, dataloader, output_names=("x", "y")):
    """ Create the agent based on a dataloader

    Args:
        dataloader ([DataLoader]): The underlying PyTorch DataLoader object
        output_names (tuple, optional): Names of the variables to write to the workspace. Defaults to ("x", "y").
    """
    super().__init__()
    self.dataloader = dataloader
    self.iter = iter(self.dataloader)
    self.output_names = output_names
    self._finished = False
    self.ghost_params = torch.nn.Parameter(torch.randn(()))