
[WIP] Replace chunked FLA with recurrent gated delta rule for T=1 decode #18667

Draft
Gasoonjia wants to merge 9 commits into main from recurrent-fla

Conversation

Contributor

@Gasoonjia Gasoonjia commented Apr 2, 2026

The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode. Replace it with plain PyTorch einsum ops that Inductor can fuse. Local benchmark results:

  • FLA GPU time: 1.085ms → 0.344ms/step (-68%)
  • Total GPU time: 12.0ms → 9.0ms/step (-25%)

Perf improves from 77.7 tokens/s to 89.9 tokens/s, but there is a negative impact on prefill perf; work to resolve the prefill issue is still ongoing.
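For reference, a single-step (T=1) gated delta rule update written as plain einsum ops might look like the sketch below. The shapes, argument names, and gating convention here are assumptions based on the standard gated DeltaNet recurrence, not the PR's actual code:

```python
import torch

def gated_delta_step(q, k, v, g, beta, state):
    """One decode step (T == 1) of the gated delta rule.

    Assumed shapes (illustrative, not taken from the PR):
      q, k  : [B, H, K]     queries / keys
      v     : [B, H, V]     values
      g     : [B, H]        log decay gate; state is scaled by exp(g)
      beta  : [B, H]        delta-rule write strength
      state : [B, H, K, V]  recurrent state S
    """
    # Decay the state by the gate.
    state = state * g.exp()[..., None, None]
    # Delta-rule correction: write beta * (v - S^T k) at key k.
    pred = torch.einsum('bhk,bhkv->bhv', k, state)        # S^T k
    state = state + torch.einsum(
        'bhk,bhv->bhkv', k, (v - pred) * beta[..., None])
    # Read out at the query.
    out = torch.einsum('bhk,bhkv->bhv', q, state)         # S^T q
    return out, state
```

Because every op here is a plain elementwise/einsum op, Inductor can fuse the whole step; with a zero initial state, g=0, and beta=1 the output reduces to (q·k) v, which gives a quick sanity check.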

The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode.
Replace with plain PyTorch einsum ops that Inductor can fuse:
- FLA GPU time: 1.085ms → 0.344ms/step (-68%)
- Total GPU time: 12.0ms → 9.0ms/step (-25%)
- Export changed to static T=1 with enable_dynamic_shape=False

pytorch-bot bot commented Apr 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18667

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Unrelated Failures

As of commit 7dd4280 with merge base 300e368:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 2, 2026

github-actions bot commented Apr 2, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Diff excerpt the review thread below is anchored on (the comment text is truncated in the page capture):

    # Recurrent gated delta rule — single-step update.
    # The model is exported with static T=1 and the C++ runner does
    # token-by-token prefill (enable_dynamic_shape=False), so T is
    q, k, v, g, beta, self.recurrent_state[:B]
Contributor
Any impact on prefill performance?

Contributor Author
It currently cuts prefill performance roughly in half; the ongoing fix will make prefill use the chunked implementation while decode uses the recurrent one.

Move decode/prefill dispatch inside the chunk_gated_delta_rule triton_op
instead of using torch.cond at model level. This follows the same pattern
as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond
incompatibility with AOTI's FunctionalTensor pipeline.

Changes:
- chunk_gated_delta_rule.py: Add fused recurrent Triton kernel for T=1,
  refactor chunked pipeline into _launch_chunked(), dispatch via Python
  if inside the @triton_op wrapper
- model.py: Remove torch.cond from GatedDeltaNet.forward(), call
  triton_op directly (dispatch is internal)
- export.py: Single-method export with dynamic seq_len dim
- main.cpp: Fix create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification; the dispatch logic
is internal to the triton_op, so no model/export/runner changes are needed.
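The internal dispatch described above can be sketched as follows. The launcher bodies here are placeholder stubs (the real ones launch Triton kernels inside the @triton_op wrapper), and the [B, H, T, K] layout is an assumption:

```python
import torch

def _launch_recurrent(q, k, v):
    # Stub standing in for the fused recurrent Triton kernel (T == 1 path).
    return "fused recurrent kernel (T == 1)"

def _launch_chunked(q, k, v):
    # Stub standing in for the 6-kernel chunked pipeline (T > 1 path).
    return "chunked pipeline (T > 1)"

def chunk_gated_delta_rule(q, k, v):
    # In the PR this plain Python `if` lives inside the @triton_op wrapper,
    # so the exported graph sees a single custom op; torch.cond at model
    # level (and its AOTI/FunctionalTensor incompatibility) is avoided.
    T = q.shape[2]  # layout assumed to be [B, H, T, K]
    if T == 1:
        return _launch_recurrent(q, k, v)
    return _launch_chunked(q, k, v)
```

This mirrors the SDPA triton_op's pow2/non-pow2 dispatch pattern the commit message cites: the branch is resolved in eager Python per call, not traced into the graph.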
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
  reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
  T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
  paths and chunk boundary edge cases
- Grid changed from (B*H,) to (V//BV, B*H) — 4x more blocks, better SM
  occupancy (128 blocks vs 32 on A100)
- BV reduced from 128 to 32 — lower register pressure, no spilling
- Removed unnecessary .contiguous() copies on squeezed inputs
- Removed debug print from triton_op dispatch
- GPU kernel time: 6us (3.47x faster than Inductor-fused native ops)
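The launch-grid arithmetic behind the "128 blocks vs 32" figure can be reproduced with illustrative values; B, H, and V below are assumptions chosen to match those numbers, not confirmed from the PR:

```python
# Before: one program per (batch, head) pair.
B, H, V = 4, 8, 128          # batch, heads, value head dim (assumed)
old_grid = (B * H,)          # grid (B*H,) -> 32 blocks

# After: the value dimension is tiled with BV=32 (was 128), adding a
# second grid axis so each (batch, head) pair gets V // BV programs.
BV = 32
new_grid = (V // BV, B * H)  # grid (V//BV, B*H)

old_blocks = old_grid[0]
new_blocks = new_grid[0] * new_grid[1]
print(old_blocks, new_blocks)  # 32 128
```

More, smaller blocks improve SM occupancy on wide GPUs like the A100, and the smaller BV tile lowers per-program register pressure, which is consistent with the "no spilling" note above.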

Labels

ciflow/cuda CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
