Skip to content

Commit 83db4bc

Browse files
committed
rebase main
1 parent 62094f2 commit 83db4bc

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,13 @@ async def generate(
6464
if self.rollout_controller:
6565
# 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦
6666
# 每个模块返回独立的data item, 在env中进行更新
67-
response_futures = []
68-
for sample in group_data_items:
69-
extra_info = sample.data.extra_info if hasattr(sample.data, "extra_info") else {}
70-
extra_info.update({"action_id": sample.uid.action_id})
71-
response_future = self.rollout_controller.rollout.remote(
67+
response_future = [
68+
self.rollout_controller.rollout.remote(
7269
prompt=sample.data.messages,
7370
input_ids=sample.data.input_ids,
7471
sample_params=sample_params,
7572
extra_params=extra_params,
76-
extra_info=extra_info,
73+
extra_info=sample.data.extra_info,
7774
)
7875
for sample in group_data_items
7976
]
@@ -129,4 +126,4 @@ async def run(
129126
for _ in group_data_items
130127
]
131128
group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_responses)
132-
return group_data_items
129+
return group_data_items

xtuner/v1/rl/base/rollout_is.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import numpy as np
4242
import torch
4343
from pydantic import BaseModel, ConfigDict
44+
import torch.distributed as dist
4445

4546
from xtuner.v1.data_proto.utils import convert_packed_to_padded, convert_padded_to_packed, masked_mean, masked_sum
4647

@@ -521,14 +522,20 @@ def compute_mismatch_metrics(
521522
return metrics
522523

523524

524-
def merge_rollout_is_metrics(rollout_is_metrics: list[dict[str, float]], device="cuda") -> dict[str, float]:
    """Merge per-batch rollout importance-sampling metrics, optionally across ranks.

    Keys whose name contains ``"max"`` are reduced with max, keys containing
    ``"min"`` with min, and every other key with the mean. When
    ``torch.distributed`` has an initialized process group, each local result
    is additionally all-reduced across ranks (MAX / MIN / AVG respectively).

    Args:
        rollout_is_metrics: List of metric dicts collected locally. All dicts
            are assumed to share the key set of the first one, and — in the
            distributed case — every rank is assumed to iterate the same keys
            in the same order (TODO confirm; a mismatch would desynchronize
            the collectives).
        device: Device for the reduction tensors; must be compatible with the
            process-group backend (e.g. ``"cuda"`` for NCCL). Defaults to
            ``"cuda"`` to preserve the original behavior.

    Returns:
        Mapping from each metric key to its merged scalar value. Empty input
        yields an empty dict.
    """
    merged: dict[str, float] = {}
    if not rollout_is_metrics:
        # Nothing to merge; avoid IndexError on rollout_is_metrics[0].
        return merged
    # Only issue collectives when a process group is actually running, so the
    # function also works in single-process (e.g. test/debug) settings.
    is_distributed = dist.is_available() and dist.is_initialized()
    for key in rollout_is_metrics[0].keys():
        values = torch.tensor([m[key] for m in rollout_is_metrics], dtype=torch.float32)
        if "max" in key:
            reduced = values.max().to(device)
            if is_distributed:
                dist.all_reduce(reduced, op=dist.ReduceOp.MAX)
        elif "min" in key:
            reduced = values.min().to(device)
            if is_distributed:
                dist.all_reduce(reduced, op=dist.ReduceOp.MIN)
        else:
            reduced = values.mean().to(device)
            if is_distributed:
                # NOTE(review): ReduceOp.AVG is NCCL-only — confirm backend.
                dist.all_reduce(reduced, op=dist.ReduceOp.AVG)
        merged[key] = reduced.item()
    return merged

xtuner/v1/rl/base/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
408408
logger_msg = f"Rollout {rollout_idx}: "
409409

410410
if len(all_rollout_is_metrics) > 0:
411-
rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics)
411+
rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)
412412
logger_msg += f"\n\nrollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"
413413

414414
sum_entropy = cast(torch.Tensor, sum_entropy)

0 commit comments

Comments
 (0)