salina.agents.transformers
salina.agents.transformers.TransformerMultiBlockAgent (Agents)
__init__(self, n_layers, embedding_size, n_heads, n_steps=None, prefix='attn_', use_layer_norm=False)
special
An agent that implements a transformer architecture. The agent reads the prefix+'in' variable and writes the prefix+'out' variable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_layers | int | Number of layers. | required |
| embedding_size | int | Size of the vectors. | required |
| n_heads | int | Number of heads. | required |
| n_steps | int | If >0, the number of steps to look back. Defaults to None. | None |
| prefix | str | The name of the variable in the workspace. Defaults to "attn_". | 'attn_' |
| use_layer_norm | bool | With/without layer normalization. Defaults to False. | False |
Source code in salina/agents/transformers.py
def __init__(
    self,
    n_layers,
    embedding_size,
    n_heads,
    n_steps=None,
    prefix="attn_",
    use_layer_norm=False,
):
    """Build a stack of `n_layers` chained `TransformerBlockAgent` instances.

    Consecutive blocks are linked through names derived from `prefix`:
    the first block receives `prefix + "in/x"` as its input name, the last
    block receives `prefix + "out/x"` as its output name, and the link
    between the k-th and (k+1)-th block (counting from 1) is named
    `prefix + str(k + 1) + "/x"`. The resulting blocks are handed, in
    order, to the parent class constructor.

    Args:
        n_layers (int): Number of blocks to stack.
        embedding_size (int): Size of the vectors; forwarded to every block.
        n_heads (int): Number of heads; forwarded to every block.
        n_steps (int, optional): If >0, the number of steps to look back.
            Forwarded unchanged to every block. Defaults to None.
        prefix (str, optional): Base name of the workspace variables.
            Defaults to "attn_".
        use_layer_norm (bool, optional): With/without layer normalization;
            forwarded to every block. Defaults to False.
    """
    final_idx = n_layers - 1

    def _input_name(idx):
        # The very first block reads from the public "<prefix>in" name.
        return prefix + "in" if idx == 0 else prefix + str(idx + 1)

    def _output_name(idx):
        # The very last block writes to the public "<prefix>out" name;
        # every other block writes to the name its successor reads from.
        return prefix + "out" if idx == final_idx else prefix + str(idx + 2)

    blocks = [
        TransformerBlockAgent(
            embedding_size,
            n_heads,
            n_steps,
            _input_name(idx) + "/x",
            _output_name(idx) + "/x",
            use_layer_norm=use_layer_norm,
        )
        for idx in range(n_layers)
    ]
    super().__init__(*blocks)