Skip to content

Commit b1fb993

Browse files
committed
feat: support rollout importance sampling helper from verl
1 parent b39581a commit b1fb993

File tree

6 files changed

+677
-40
lines changed

6 files changed

+677
-40
lines changed

xtuner/v1/data_proto/utils.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,89 @@ def gather_for_sequence_parallel(input, dim: int, sp_group: dist.ProcessGroup):
110110
output = torch.cat(tensor_list, dim=dim).contiguous()
111111

112112
return output
113+
114+
115+
def convert_padded_to_packed(
116+
input: torch.Tensor, num_tokens: torch.Tensor | list, padding_side: str = "right"
117+
) -> torch.Tensor:
118+
"""Convert a padded tensor (B, L, ...) to a packed tensor (1,
119+
sum(num_tokens), ...).
120+
121+
Args:
122+
input: The input tensor to be converted.
123+
num_tokens: The number of tokens of each sequence in the padded input.
124+
"""
125+
if isinstance(num_tokens, torch.Tensor):
126+
num_tokens = num_tokens.tolist()
127+
if padding_side == "right":
128+
return torch.cat([input[i, : num_tokens[i]] for i in range(len(num_tokens))], dim=0).unsqueeze(0)
129+
elif padding_side == "left":
130+
return torch.cat([input[i, -num_tokens[i] :] for i in range(len(num_tokens))], dim=0).unsqueeze(0)
131+
else:
132+
raise ValueError(f"Invalid padding_side: {padding_side}. Must be 'right' or 'left'.")
133+
134+
135+
def convert_packed_to_padded(
136+
input: torch.Tensor, num_tokens: torch.Tensor | list, padding_value: float, padding_side: str = "right"
137+
) -> torch.Tensor:
138+
"""Convert a packed tensor (1, sum(num_tokens), ...) to a padded tensor
139+
(len(num_tokens), max(num_tokens), ...).
140+
141+
Args:
142+
input: The input tensor to be converted.
143+
num_tokens: The number of tokens of each sequence in the padded input.
144+
"""
145+
unpacked_input = unpack_sequence(input, num_tokens) # list of (1, num_tokens[i], ...)
146+
max_length = max(num_tokens)
147+
padded_input = torch.full(
148+
(len(num_tokens), max_length, *input.shape[2:]), padding_value, dtype=input.dtype, device=input.device
149+
)
150+
for i, seq in enumerate(unpacked_input):
151+
if padding_side == "right":
152+
padded_input[i, : num_tokens[i]] = seq[0]
153+
elif padding_side == "left":
154+
padded_input[i, -num_tokens[i] :] = seq[0]
155+
else:
156+
raise ValueError(f"Invalid padding_side: {padding_side}. Must be 'right' or 'left'.")
157+
return padded_input
158+
159+
160+
def masked_sum(
161+
input: torch.Tensor,
162+
mask: torch.Tensor,
163+
axis: int | None = None,
164+
num_tokens: torch.Tensor | list | None = None,
165+
unpack_sequence: bool = False,
166+
) -> torch.Tensor:
167+
"""
168+
Args:
169+
input: The input tensor to be masked.
170+
mask: The mask tensor to be applied.
171+
axis: The dimension along which the tensor should be masked.
172+
num_tokens: The number of tokens of each sequence in the packed input.
173+
unpack_sequence: Whether to unpack the sequence.
174+
"""
175+
if unpack_sequence:
176+
input = convert_packed_to_padded(input, num_tokens, padding_value=0, padding_side="right")
177+
mask = convert_packed_to_padded(mask, num_tokens, padding_value=0, padding_side="right")
178+
valid_values = torch.where(mask.bool(), input, 0.0)
179+
return (valid_values * mask).sum(axis=axis)
180+
181+
182+
def masked_mean(
183+
input: torch.Tensor,
184+
mask: torch.Tensor,
185+
axis: int | None = None,
186+
num_tokens: torch.Tensor | list | None = None,
187+
unpack_sequence: bool = False,
188+
) -> torch.Tensor:
189+
"""
190+
Args:
191+
input: The input tensor to be masked.
192+
mask: The mask tensor to be applied.
193+
axis: The dimension along which the tensor should be masked.
194+
num_tokens: The number of tokens of each sequence in the packed input.
195+
unpack_sequence: Whether to unpack the sequence.
196+
"""
197+
sum = masked_sum(input, mask, axis=axis, num_tokens=num_tokens, unpack_sequence=unpack_sequence)
198+
return sum / (mask.sum(axis=axis) + 1e-8)

xtuner/v1/rl/base/loss.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
from xtuner.v1.loss import BaseLossConfig
99
from xtuner.v1.loss.base_loss_ctx import BaseLossContext
10-
10+
from .rollout_is import RolloutImportanceSampling
1111
from ..utils import sp_split
1212

1313

14+
1415
T = TypeVar("T")
1516

1617

@@ -32,7 +33,9 @@ class BaseRLLossConfig(BaseLossConfig):
3233
kl_loss_type (Literal["kl", "k1", "abs", "mse", "k2", "low_var_kl", "k3"] | None):
3334
Type of KL penalty computation method. Different types provide various
3435
regularization behaviors and numerical stability properties. Defaults to None.
35-
36+
rollout_is (RolloutImportanceSampling): Configuration parameters for the rollout importance sampling.
37+
Contains algorithm-specific parameters for rollout importance sampling.
38+
Defaults to RolloutImportanceSampling().
3639
**Abstract Method:**
3740
loss_ctx_cls: Must be implemented by subclasses to return the appropriate
3841
loss context class for the specific RL algorithm.
@@ -72,6 +75,7 @@ class BaseRLLossConfig(BaseLossConfig):
7275
use_kl_loss: bool = False
7376
kl_loss_coef: float = 0.001
7477
kl_loss_type: Literal["kl", "k1", "abs", "mse", "k2", "low_var_kl", "k3"] | None = None
78+
rollout_is: RolloutImportanceSampling = RolloutImportanceSampling()
7579

7680
@property
7781
def loss_ctx_cls(self) -> type[BaseLossContext]:
@@ -86,24 +90,38 @@ class RLLossContextInputItem(BaseModel):
8690
advantages (torch.Tensor): Advantage estimates for the actions taken.
8791
old_logprobs (torch.Tensor | None): Log probabilities from the old policy.
8892
ref_logprobs (torch.Tensor | None): Reference log probabilities for KL penalty, if used.
93+
rollout_logprobs (torch.Tensor | None): Rollout log probabilities from inference engine, used for importance sampling.
94+
is_weights (torch.Tensor | None): Importance sampling weights. If None, importance sampling is not used.
8995
"""
9096

9197
model_config = ConfigDict(title="RLLossContextInputItem", extra="allow", arbitrary_types_allowed=True)
9298
shifted_labels: torch.Tensor
9399
advantages: torch.Tensor
94100
old_logprobs: torch.Tensor | None = None
95101
ref_logprobs: torch.Tensor | None = None
102+
rollout_logprobs: torch.Tensor | None = None
103+
is_weights: torch.Tensor | None = None
96104

97105
def sp_split(self, sp_mesh: DeviceMesh) -> Self:
98106
shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100)
99107
advantages = sp_split(self.advantages, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0)
108+
if self.rollout_logprobs is not None:
109+
rollout_logprobs = sp_split(self.rollout_logprobs, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0)
110+
else:
111+
rollout_logprobs = None
112+
if self.is_weights is not None:
113+
is_weights = sp_split(self.is_weights, sp_mesh=sp_mesh, split_dim=1, padding_value=1.0)
114+
else:
115+
is_weights = None
100116
# 这里不用对old_logprobs和ref_logprobs进行sp_split,因为他是模型 fwd 生成的
101117
# 模型 fwd 前一定会对 seq_ctx 进行 sp_split
102118
return type(self)(
103119
shifted_labels=shifted_labels,
104120
advantages=advantages,
105121
old_logprobs=self.old_logprobs,
106122
ref_logprobs=self.ref_logprobs,
123+
rollout_logprobs=rollout_logprobs,
124+
is_weights=is_weights,
107125
)
108126

109127
def to(self, device: torch.device | str) -> Self:
@@ -113,4 +131,8 @@ def to(self, device: torch.device | str) -> Self:
113131
self.old_logprobs = self.old_logprobs.to(device)
114132
if self.ref_logprobs is not None:
115133
self.ref_logprobs = self.ref_logprobs.to(device)
134+
if self.rollout_logprobs is not None:
135+
self.rollout_logprobs = self.rollout_logprobs.to(device)
136+
if self.is_weights is not None:
137+
self.is_weights = self.is_weights.to(device)
116138
return self

0 commit comments

Comments
 (0)