@@ -110,3 +110,89 @@ def gather_for_sequence_parallel(input, dim: int, sp_group: dist.ProcessGroup):
110110 output = torch .cat (tensor_list , dim = dim ).contiguous ()
111111
112112 return output
113+
114+
115+ def convert_padded_to_packed (
116+ input : torch .Tensor , num_tokens : torch .Tensor | list , padding_side : str = "right"
117+ ) -> torch .Tensor :
118+ """Convert a padded tensor (B, L, ...) to a packed tensor (1,
119+ sum(num_tokens), ...).
120+
121+ Args:
122+ input: The input tensor to be converted.
123+ num_tokens: The number of tokens of each sequence in the padded input.
124+ """
125+ if isinstance (num_tokens , torch .Tensor ):
126+ num_tokens = num_tokens .tolist ()
127+ if padding_side == "right" :
128+ return torch .cat ([input [i , : num_tokens [i ]] for i in range (len (num_tokens ))], dim = 0 ).unsqueeze (0 )
129+ elif padding_side == "left" :
130+ return torch .cat ([input [i , - num_tokens [i ] :] for i in range (len (num_tokens ))], dim = 0 ).unsqueeze (0 )
131+ else :
132+ raise ValueError (f"Invalid padding_side: { padding_side } . Must be 'right' or 'left'." )
133+
134+
135+ def convert_packed_to_padded (
136+ input : torch .Tensor , num_tokens : torch .Tensor | list , padding_value : float , padding_side : str = "right"
137+ ) -> torch .Tensor :
138+ """Convert a packed tensor (1, sum(num_tokens), ...) to a padded tensor
139+ (len(num_tokens), max(num_tokens), ...).
140+
141+ Args:
142+ input: The input tensor to be converted.
143+ num_tokens: The number of tokens of each sequence in the padded input.
144+ """
145+ unpacked_input = unpack_sequence (input , num_tokens ) # list of (1, num_tokens[i], ...)
146+ max_length = max (num_tokens )
147+ padded_input = torch .full (
148+ (len (num_tokens ), max_length , * input .shape [2 :]), padding_value , dtype = input .dtype , device = input .device
149+ )
150+ for i , seq in enumerate (unpacked_input ):
151+ if padding_side == "right" :
152+ padded_input [i , : num_tokens [i ]] = seq [0 ]
153+ elif padding_side == "left" :
154+ padded_input [i , - num_tokens [i ] :] = seq [0 ]
155+ else :
156+ raise ValueError (f"Invalid padding_side: { padding_side } . Must be 'right' or 'left'." )
157+ return padded_input
158+
159+
160+ def masked_sum (
161+ input : torch .Tensor ,
162+ mask : torch .Tensor ,
163+ axis : int | None = None ,
164+ num_tokens : torch .Tensor | list | None = None ,
165+ unpack_sequence : bool = False ,
166+ ) -> torch .Tensor :
167+ """
168+ Args:
169+ input: The input tensor to be masked.
170+ mask: The mask tensor to be applied.
171+ axis: The dimension along which the tensor should be masked.
172+ num_tokens: The number of tokens of each sequence in the packed input.
173+ unpack_sequence: Whether to unpack the sequence.
174+ """
175+ if unpack_sequence :
176+ input = convert_packed_to_padded (input , num_tokens , padding_value = 0 , padding_side = "right" )
177+ mask = convert_packed_to_padded (mask , num_tokens , padding_value = 0 , padding_side = "right" )
178+ valid_values = torch .where (mask .bool (), input , 0.0 )
179+ return (valid_values * mask ).sum (axis = axis )
180+
181+
182+ def masked_mean (
183+ input : torch .Tensor ,
184+ mask : torch .Tensor ,
185+ axis : int | None = None ,
186+ num_tokens : torch .Tensor | list | None = None ,
187+ unpack_sequence : bool = False ,
188+ ) -> torch .Tensor :
189+ """
190+ Args:
191+ input: The input tensor to be masked.
192+ mask: The mask tensor to be applied.
193+ axis: The dimension along which the tensor should be masked.
194+ num_tokens: The number of tokens of each sequence in the packed input.
195+ unpack_sequence: Whether to unpack the sequence.
196+ """
197+ sum = masked_sum (input , mask , axis = axis , num_tokens = num_tokens , unpack_sequence = unpack_sequence )
198+ return sum / (mask .sum (axis = axis ) + 1e-8 )
0 commit comments