我试图将一系列 m 个输入序列(每个有 n 个项目(s m,n))转换为一个张量,其中 t p,q,r = s p,q s p,r
虽然我的代码确实有效,但我觉得必须有更好的解决方案。这就是我得到的。
# nseqs is the number of sequences
# seq_length is the sequence length
# seq is a list of sequences
output = np.empty((nseqs, seq_length, seq_length))
for n in range(nseqs):
for i, j in enumerate(seq[n]):
output[n, i, :] = j
output[n, :, :] *= output[n, :, :].T
Run Code Online (Sandbox Code Playgroud)
更重要的是,有没有一种方法可以重新调整output张量,以便乘法阶段在单个步骤中完成,而无需这些循环?
通过广播可以达到预期的效果。第一个维度对齐,其他维度可以使用引入的单位维度强制转换为正确的形状None:
t = s[:, :, None] * s[:, None, :]
Run Code Online (Sandbox Code Playgroud)
您可以通过广播分配输出元素,然后仅对最后一个维度进行转置,从而实现无需循环的原始方法:
output = np.empty((seq.shape[0], seq.shape[1], seq.shape[1]), seq.dtype)
output[:] = seq[..., None]
output *= output.transpose(0, 2, 1)
Run Code Online (Sandbox Code Playgroud)