Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions trinity/buffer/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"reward_shaping_mapper": "trinity.buffer.operators.mappers.reward_shaping_mapper.RewardShapingMapper",
"pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator",
"data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator",
"invalid_reward_filter": "trinity.buffer.operators.filters.reward_filter.InvalidRewardFilter",
},
)

Expand Down
16 changes: 16 additions & 0 deletions trinity/buffer/operators/filters/reward_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,19 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
final_count = len(result_exps)
metrics["filtered_count"] = original_count - final_count
return result_exps, metrics


class InvalidRewardFilter(ExperienceOperator):
    """Drop experiences whose reward is missing or not a number.

    An experience is kept only when its ``reward`` attribute is not ``None``
    and compares equal to itself. NaN is the value that fails the
    self-equality test, so NaN rewards are discarded together with ``None``
    ones.

    Note: this operator reads ``reward`` directly from each experience; it
    does not compute rewards itself.
    """

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
        """Return the valid experiences and a metrics dict.

        Args:
            exps: Experiences to screen.

        Returns:
            A tuple ``(kept, metrics)`` where ``kept`` preserves the input
            order of the surviving experiences and ``metrics`` holds
            ``"filtered_count"``, the number of experiences removed.
        """
        valid_exps = []
        for exp in exps:
            reward = exp.reward
            if reward is None:
                # Missing reward: nothing to train on.
                continue
            if not (reward == reward):
                # NaN is never equal to itself, so this catches NaN rewards.
                continue
            valid_exps.append(exp)

        removed = len(exps) - len(valid_exps)
        return valid_exps, {"filtered_count": removed}