[WIP] Replace chunked FLA with recurrent gated delta rule for T=1 decode #18667
Conversation
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode. Replace it with plain PyTorch einsum ops that Inductor can fuse:

- FLA GPU time: 1.085 ms → 0.344 ms/step (-68%)
- Total GPU time: 12.0 ms → 9.0 ms/step (-25%)
- Export changed to static T=1 with `enable_dynamic_shape=False`
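For reference, a minimal sketch of what a single-step (T=1) gated delta rule update looks like as einsum ops Inductor can fuse. The helper name `gated_delta_step`, the shapes, and the `[B, H, K, V]` state layout are assumptions for illustration, not the PR's actual code:

```python
import torch

def gated_delta_step(q, k, v, g, beta, S):
    """One decode step of the gated delta rule (illustrative sketch).

    q, k : [B, H, K]     query/key for the new token
    v    : [B, H, V]     value for the new token
    g    : [B, H]        log decay gate (state is scaled by exp(g))
    beta : [B, H]        write strength
    S    : [B, H, K, V]  recurrent key->value state
    """
    S = S * g.exp()[..., None, None]                  # decay the old state
    pred = torch.einsum("bhk,bhkv->bhv", k, S)        # what S currently stores for k
    delta = beta[..., None] * (v - pred)              # gated prediction error
    S = S + torch.einsum("bhk,bhv->bhkv", k, delta)   # rank-1 delta update
    o = torch.einsum("bhk,bhkv->bhv", q, S)           # read out with the query
    return o, S
```

With `g = 0`, `beta = 1`, and a unit-norm key, the updated state stores `v` exactly under `k`, which is the delta rule's defining property.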
❌ 3 New Failures, 2 Unrelated Failures as of commit 7dd4280 with merge base 300e368.

NEW FAILURES: the following jobs have failed.
BROKEN TRUNK: the following jobs failed but were also present on the merge base; rebase onto the `viable/strict` branch to avoid these failures.
examples/models/qwen3_5_moe/model.py (Outdated)

```python
q, k, v, g, beta, self.recurrent_state[:B]
# Recurrent gated delta rule — single-step update.
# The model is exported with static T=1 and the C++ runner does
# token-by-token prefill (enable_dynamic_shape=False), so T is
```
Any impact on prefill performance?
It roughly halves prefill performance. An ongoing fix will make prefill use the chunked implementation while decode uses the recurrent one.
Move decode/prefill dispatch inside the `chunk_gated_delta_rule` triton_op instead of using `torch.cond` at model level. This follows the same pattern as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids the incompatibility between `torch.cond` and AOTI's FunctionalTensor pipeline.

Changes:
- `chunk_gated_delta_rule.py`: add a fused recurrent Triton kernel for T=1, refactor the chunked pipeline into `_launch_chunked()`, and dispatch via a Python `if` inside the `@triton_op` wrapper
- `model.py`: remove `torch.cond` from `GatedDeltaNet.forward()` and call the triton_op directly (dispatch is internal)
- `export.py`: single-method export with a dynamic seq_len dim
- `main.cpp`: fix the `create_text_llm_runner` API signature
Only `chunk_gated_delta_rule.py` needs modification: the dispatch logic is internal to the triton_op, so no model/export/runner changes are needed.
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
paths and chunk boundary edge cases
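A sketch of the chunk-boundary property these tests exercise, using a naive sequential recurrence as its own reference: running a full sequence at once must match running it split at a boundary with the state carried over. The function names, shapes, and tolerances here are assumptions, not the PR's actual test code:

```python
import torch

def naive_gated_delta(q, k, v, g, beta, S):
    # Naive sequential reference: one gated delta rule step per token.
    # q, k: [B,H,T,K]; v: [B,H,T,V]; g, beta: [B,H,T]; S: [B,H,K,V]
    outs = []
    for t in range(q.shape[2]):
        S = S * g[:, :, t].exp()[..., None, None]
        err = beta[:, :, t, None] * (
            v[:, :, t] - torch.einsum("bhk,bhkv->bhv", k[:, :, t], S))
        S = S + torch.einsum("bhk,bhv->bhkv", k[:, :, t], err)
        outs.append(torch.einsum("bhk,bhkv->bhv", q[:, :, t], S))
    return torch.stack(outs, dim=2), S

def check_split(T, split, B=2, H=2, K=4, V=8):
    # Full-sequence run must equal two runs split at `split` with carried state.
    torch.manual_seed(0)
    q, k = torch.randn(B, H, T, K), torch.randn(B, H, T, K)
    v = torch.randn(B, H, T, V)
    g, beta = -torch.rand(B, H, T), torch.rand(B, H, T)
    S0 = torch.zeros(B, H, K, V)
    o_full, S_full = naive_gated_delta(q, k, v, g, beta, S0)
    o_a, S_mid = naive_gated_delta(q[:, :, :split], k[:, :, :split],
                                   v[:, :, :split], g[:, :, :split],
                                   beta[:, :, :split], S0)
    o_b, S_end = naive_gated_delta(q[:, :, split:], k[:, :, split:],
                                   v[:, :, split:], g[:, :, split:],
                                   beta[:, :, split:], S_mid)
    assert torch.allclose(torch.cat([o_a, o_b], dim=2), o_full, atol=1e-4)
    assert torch.allclose(S_end, S_full, atol=1e-4)

for T in (2, 32, 63, 64, 65, 128):
    check_split(T, split=T // 2)
```

Sweeping T across values just below, at, and above 64 is what catches off-by-one errors at the chunk boundary, since both dispatch paths must agree on the carried state.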
- Grid changed from `(B*H,)` to `(V//BV, B*H)`: 4x more blocks, better SM occupancy (128 blocks vs 32 on A100)
- BV reduced from 128 to 32: lower register pressure, no spilling
- Removed unnecessary `.contiguous()` copies on squeezed inputs
- Removed debug print from triton_op dispatch
- GPU kernel time: 6 us (3.47x faster than Inductor-fused native ops)
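The occupancy arithmetic behind the grid change, as plain Python. The concrete `B`, `H`, `V`, and `BV` values are assumptions chosen to reproduce the 32-vs-128 block counts quoted above:

```python
import math

B, H = 8, 4      # assumed batch/head split; only B*H = 32 is implied above
V, BV = 128, 32  # assumed value head dim and the new per-program tile along V

old_grid = (B * H,)          # one program per (batch, head) pair
new_grid = (V // BV, B * H)  # also parallelize over tiles of the V dimension

print(math.prod(old_grid), "->", math.prod(new_grid))  # 32 -> 128 blocks
```

With only 32 blocks, most of an A100's 108 SMs sit idle; splitting the V dimension into tiles multiplies the block count without changing the total work.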
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode. Replaced it with plain PyTorch einsum ops that Inductor can fuse. Local benchmark: decode improves from 77.7 token/s to 89.9 token/s, but prefill performance regresses; a fix for the prefill regression is still in progress.