How should the function _transpose_shift be understood? #34

@yangshuodelove

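I have a few questions about the _transpose_shift function in relative_transformer.py:
(1) What transformation of the matrix does it implement (e.g. a shift, a rotation)?
(2) How should I understand the way it achieves this transformation?
(3) Why does indice, on the third-to-last line, select only the odd rows?
(4) What is the difference between _transpose_shift and _shift, and how are they related?
Many thanks.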


def _transpose_shift(self, E):
    # E = [B, N, L, 2*L] = [bsz, n_head, max_len, 2*max_len], e.g. [2, 4, 68, 136]
    bsz, n_head, max_len, _ = E.size()
    zero_pad = E.new_zeros(bsz, n_head, max_len, 1)
    E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len)  # [B, N, 2L+1, L]
    indice = (torch.arange(max_len) * 2 + 1).to(E.device)  # select the odd-indexed rows: [1, 3, 5, ..., 135]
    E = E.index_select(index=indice, dim=-2).transpose(-1, -2)  # [B, N, L, L]
    return E
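
One way to probe questions (1) to (3) is to compare the function against an explicit index formula on a small input. The sketch below is only a sanity check, not an explanation from the repository authors. It assumes a standalone copy of the method (`transpose_shift`, with `self` dropped) and a hypothetical loop-based helper `transpose_shift_reference` that encodes the guess `out[i, j] = E[j, L + i - j]`. The guess comes from the flat layout: after padding, each row holds 2L+1 values, so odd row 2k+1 of the reshaped tensor starts at flat offset (2k+1)·L = k·(2L+1) + (L − k), which is column L − k of original row k.

```python
import torch


def transpose_shift(E):
    # Standalone copy of the method above (self dropped) so it can be run directly.
    bsz, n_head, max_len, _ = E.size()
    zero_pad = E.new_zeros(bsz, n_head, max_len, 1)
    # Padded rows have 2L+1 entries; reshaping to width L yields 2L+1 rows.
    E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len)
    indice = (torch.arange(max_len) * 2 + 1).to(E.device)
    return E.index_select(index=indice, dim=-2).transpose(-1, -2)


def transpose_shift_reference(E):
    # Hypothetical reference: the assumed mapping out[i, j] = E[j, L + i - j].
    bsz, n_head, max_len, _ = E.size()
    out = E.new_zeros(bsz, n_head, max_len, max_len)
    for i in range(max_len):
        for j in range(max_len):
            out[..., i, j] = E[..., j, max_len + i - j]
    return out


L = 4
# Distinct values make every element traceable to its source position.
E = torch.arange(2 * 3 * L * 2 * L, dtype=torch.float32).view(2, 3, L, 2 * L)
assert torch.equal(transpose_shift(E), transpose_shift_reference(E))
```

If the assertion holds, then under this reading row k of the input is cropped to the L-wide window starting at column L − k (a per-row shift), and the result is transposed. Column 0 and the zero-padding column are never selected, so the padding only serves to make the reshape line up.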
