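I have a few questions about the _transpose_shift function in relative_transformer.py:
(1) What transformation does it apply to the matrix (e.g., a shift or a rotation)?
(2) How should I understand the way it implements this transformation?
(3) In the third-to-last line, why does indice select only the odd-indexed rows?
(4) How does _transpose_shift differ from _shift, and how are the two related?
Thank you very much.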
    def _transpose_shift(self, E):
        # E: [B, N, L, 2L] = [bsz, n_head, max_len, 2 * max_len], e.g. [2, 4, 68, 136]
        bsz, n_head, max_len, _ = E.size()
        zero_pad = E.new_zeros(bsz, n_head, max_len, 1)
        E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len)  # [B, N, 2L + 1, L]
        indice = (torch.arange(max_len) * 2 + 1).to(E.device)  # odd-indexed rows: [1, 3, 5, ..., 135]
        E = E.index_select(index=indice, dim=-2).transpose(-1, -2)
        return E
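
One way to approach questions (1) to (3) is to trace the same four steps (zero-pad, reshape, pick odd rows, transpose) by hand on a tiny input. The following is a minimal pure-Python sketch on a single L x 2L matrix, not the repository's code: the helper name transpose_shift_2d and the 10 * i + j encoding of the entries are hypothetical, chosen only so that each output value shows which row i and column j of E it came from.

```python
def transpose_shift_2d(E):
    """Re-enact pad / reshape / select / transpose on one L x 2L list of lists.

    Hypothetical helper for illustration; it mirrors the tensor steps of
    _transpose_shift for a single (batch, head) slice.
    """
    L = len(E)
    assert all(len(row) == 2 * L for row in E)
    # 1. Append one zero to every row: L rows of length 2L + 1.
    padded = [row + [0] for row in E]
    # 2. Flatten row-major and re-chunk into rows of length L: 2L + 1 rows.
    flat = [x for row in padded for x in row]
    rows = [flat[r * L:(r + 1) * L] for r in range(2 * L + 1)]
    # 3. Keep the odd-indexed rows 1, 3, ..., 2L - 1.
    picked = [rows[2 * k + 1] for k in range(L)]
    # 4. Swap the two axes (the transpose(-1, -2) step).
    return [[picked[k][c] for k in range(L)] for c in range(L)]


L = 3
# Entry (i, j) of E is stored as 10 * i + j, so 12 means "row 1, column 2".
E = [[10 * i + j for j in range(2 * L)] for i in range(L)]
out = transpose_shift_2d(E)
for row in out:
    print(row)
# [3, 12, 21]
# [4, 13, 22]
# [5, 14, 23]

# The same mapping written as a direct index formula.
expected = [[E[j][L + i - j] for j in range(L)] for i in range(L)]
print(out == expected)  # True
```

On this input the result satisfies out[i][j] == E[j][L + i - j]. Because each padded row has 2L + 1 entries, row k of E starts one position later (mod L) than row k - 1 after re-chunking, so odd row 2k + 1 holds exactly E[k][L - k], ..., E[k][2L - 1 - k]. If column L + d of E is read as relative offset d (an assumption about the surrounding code, not something stated in this post), then out[i][j] is row j's entry for offset i - j.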