Skip to content

salina.agents.transformers

salina.agents.transformers.TransformerMultiBlockAgent (Agents)

__init__(self, n_layers, embedding_size, n_heads, n_steps=None, prefix='attn_', use_layer_norm=False) special

An agent implementing a multi-layer transformer architecture: it stacks `n_layers` transformer blocks, reads the `prefix+'in/x'` variable, and writes the `prefix+'out/x'` variable.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_layers` | `int` | Number of layers (stacked transformer blocks). | *required* |
| `embedding_size` | `int` | Size of the vectors. | *required* |
| `n_heads` | `int` | Number of heads. | *required* |
| `n_steps` | `int` | If > 0, the number of steps to look back. | `None` |
| `prefix` | `str` | Prefix of the variable names in the workspace. | `'attn_'` |
| `use_layer_norm` | `bool` | Whether to apply layer normalization. | `False` |
Source code in salina/agents/transformers.py
def __init__(
    self,
    n_layers,
    embedding_size,
    n_heads,
    n_steps=None,
    prefix="attn_",
    use_layer_norm=False,
):
    """Stack ``n_layers`` transformer blocks into a single agent.

    The first block reads the workspace variable ``prefix + "in/x"`` and the
    last block writes ``prefix + "out/x"``. Consecutive blocks communicate
    through the intermediate variables ``prefix + "2/x"``, ``prefix + "3/x"``,
    ..., ``prefix + str(n_layers) + "/x"``. With a single layer, the only
    block maps ``prefix + "in/x"`` directly to ``prefix + "out/x"``.

    Args:
        n_layers (int): Number of ``TransformerBlockAgent`` blocks to stack.
            If it is not positive, no block is created.
        embedding_size (int): Size of the vectors; forwarded to every block.
        n_heads (int): Number of heads; forwarded to every block.
        n_steps (int, optional): Forwarded unchanged to every block. If > 0,
            it corresponds to the number of steps to look back.
            Defaults to None.
        prefix (str, optional): Prefix of the variable names in the
            workspace. Defaults to "attn_".
        use_layer_norm (bool, optional): Forwarded to every block;
            with/without layer normalization. Defaults to False.
    """
    # Boundary i sits in front of block i: boundary 0 is the external input,
    # boundary n_layers is the external output, and every interior boundary
    # i (0 < i < n_layers) is named prefix + str(i + 1).
    boundaries = (
        [prefix + "in"]
        + [prefix + str(i + 1) for i in range(1, n_layers)]
        + [prefix + "out"]
    )
    # Block k reads boundary k and writes boundary k + 1. Blocks are built in
    # order and handed to the parent constructor in that same order.
    # Iterating range(n_layers) (not the boundary pairs) keeps n_layers <= 0
    # from creating any block.
    blocks = [
        TransformerBlockAgent(
            embedding_size,
            n_heads,
            n_steps,
            boundaries[k] + "/x",
            boundaries[k + 1] + "/x",
            use_layer_norm=use_layer_norm,
        )
        for k in range(n_layers)
    ]
    super().__init__(*blocks)