Skip to content

Conversation

@oleksost
Copy link
Contributor

✨ Description

Bugs in the recurrent step path used in generation.

This makes GSM8k generations actually sensical.

πŸ” Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

πŸ“ Changes

List the key changes introduced in this PR:

  1. Change A
  2. Change B

βœ… Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • πŸ“œ I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • πŸŽ‰ The functionality is complete, and I have tested the changes.
  • πŸ“ I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • πŸ‹ I have updated the Docker configuration or dependencies, if applicable.
  • πŸ”„ I have ensured compatibility with the existing setup after dependency changes.

Testing

  • πŸ§ͺ I have added or updated tests to cover my changes.
  • βœ”οΈ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • πŸ‹οΈ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • πŸ“Š I have run benchmarks where applicable to evaluate the performance impact.
  • βœ… The benchmarks show no performance regression.
  • πŸš€ The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • πŸ“ˆ I have provided benchmark results and detailed any performance impact below, if applicable.

πŸ“Š Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


πŸ—’οΈ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@oleksost oleksost requested a review from tscholak January 16, 2026 20:31
Copy link
Collaborator

@tscholak tscholak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

important fixes!

Comment on lines +1593 to +1594
initial_state=recurrent_state,
output_final_state=past_key_values is not None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch!

4. Update state: S = S + k βŠ— delta
5. Output: o = S @ q (scaled)
"""
input_dtype = query.dtype
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the torch implementation, right? under normal circumstances we wouldn't hit this code path because we're using the FLA kernel, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The torch fallback is torch_chunk_gated_delta_rule. The _recurrent_gated_delta_rule is used at decoding time in recurrent mode (i.e. after prefill)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's an FLA kernel for chunk_gated_delta_rule that we can use. I think you saw that too now. thanks for approving!

@tscholak
Copy link
Collaborator

when I looked at these changes, I realized that it doesn't make sense to have a torch fallback. I decided to just remove that code. The outcome is #451

@oleksost oleksost closed this Jan 19, 2026
@oleksost oleksost reopened this Jan 19, 2026
Copy link
Collaborator

@tscholak tscholak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good stuff

@tscholak tscholak merged commit b5cca16 into main Jan 19, 2026
1 of 2 checks passed
@tscholak tscholak deleted the oo/apriel_modeling_bug branch January 19, 2026 14:50
tscholak added a commit that referenced this pull request Jan 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants