我想使用 PyTorch 的 nn.MultiheadAttention 但它不起作用。
我只是想使用pytorch的功能来手动计算注意力的例子
当我尝试运行这个例子时总是遇到错误。
import torch.nn as nn
embed_dim = 4
num_heads = 1
x = [
[1, 0, 1, 0], # Input 1
[0, 2, 0, 2], # Input 2
[1, 1, 1, 1] # Input 3
]
x = torch.tensor(x, dtype=torch.float32)
w_key = [
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]
]
w_query = [
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]
]
w_value = [
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
keys = x @ w_key
querys = x @ w_query
values = x @ w_value
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
Run Code Online (Sandbox Code Playgroud)
小智 8
尝试这个。
首先,你的 x 是一个 (3x4) 矩阵。所以你需要一个 (4x4) 的权重矩阵。
似乎 nn.MultiheadAttention 仅支持批处理模式,尽管文档说它支持非批处理输入。因此,让我们通过 使您的一个数据点处于批处理模式.unsqueeze(0)
。
embed_dim = 4
num_heads = 1
x = [
[1, 0, 1, 0], # Seq 1
[0, 2, 0, 2], # Seq 2
[1, 1, 1, 1] # Seq 3
]
x = torch.tensor(x, dtype=torch.float32)
w_key = [
[0, 0, 1, 1],
[1, 1, 0, 1],
[0, 1, 0, 1],
[1, 1, 0, 1]
]
w_query = [
[1, 0, 1, 1],
[1, 0, 0, 1],
[0, 0, 1, 1],
[0, 1, 1, 1]
]
w_value = [
[0, 2, 0, 1],
[0, 3, 0, 1],
[1, 0, 3, 1],
[1, 1, 0, 1]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
keys = (x @ w_key).unsqueeze(0) # to batch mode
querys = (x @ w_query).unsqueeze(0)
values = (x @ w_value).unsqueeze(0)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
8685 次 |
最近记录: |