Skip to content

support cp ,fix qwen3.5 gdn sp#138

Open
meichangsu1 wants to merge 3 commits intomodelscope:mainfrom
meichangsu1:fsdp_cp_ljl
Open

support cp ,fix qwen3.5 gdn sp#138
meichangsu1 wants to merge 3 commits intomodelscope:mainfrom
meichangsu1:fsdp_cp_ljl

Conversation

@meichangsu1
Copy link
Copy Markdown
Collaborator

@meichangsu1 meichangsu1 commented Apr 2, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR adds context parallel and Qwen3.5 Gated DeltaNet sequence parallel support to the transformers stack, and refactors sequence parallel into a package-based implementation.

Main changes:

  • Refactor sequence_parallel.py into sequence_parallel/ and add shared utilities.
  • Add derived ring / zigzag ring attention support for CP + SP.
  • Add Qwen3.5 linear attention SP support in linear_attention_sp.py;Ring attention is not supported for this path yet.
  • Update transformers model / processor paths to work with the new SP+CP flow.
  • Adjust loss metric aggregation for Ulysses replicated loss behavior.
  • Update cookbook examples for sp_fsdp_dense.
  • Add test coverage for:
    • Qwen3.5 linear attention SP alignment
    • sequence parallel + context parallel behavior
  • Remove outdated tests/moe/test_expert_parallel_qwen3_fsdp_sp.py.

Experiment results

@meichangsu1 meichangsu1 changed the title Fsdp cp ljl support cp ,fix qwen3.5 gdn sp Apr 2, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request significantly enhances sequence parallelism support by implementing ZigZag Ring Attention for long-sequence training and Ulysses-style sequence parallelism for Qwen3.5 linear attention. It also introduces multimodal deepstack patching for Qwen3-VL and refactors the SequenceParallel strategy to better handle complex device meshes and packed/varlen inputs. Feedback focuses on improving code maintainability and robustness, specifically by grouping attributes in the SequenceParallel constructor, removing redundant logic and unused imports, replacing deprecated inspection methods, and centralizing duplicated loss-gathering logic.

Comment on lines +37 to 53
self.seq_world_size = None
self.sp_world_size = None
self.rp_world_size = None
self.dp_world_size = None
self.world_size = None
self.attn_implementation = None
self.model_dtype = None
self.tokenizer = None
self.device_mesh = None
self._sp_group = None
self._rp_group = None
self._data_rank_group = None
self._sp_rank = 0
self._rp_rank = 0
self.num_heads = None
self.causal_mask_func = None
self.extra_kwargs = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The init method is becoming quite large with many attributes. Consider grouping related attributes into a dataclass or a separate configuration object to improve maintainability.

Comment on lines +236 to +239
if query.shape[2] != total_tokens:
raise ValueError('Packed/varlen flash_attention_2 expects query sequence length to match '
f'cu_seqlens total tokens, got query_seq_len={query.shape[2]} '
f'and cu_seqlens_total={total_tokens}.')
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check if world_size and world_size > 1 is redundant if world_size is guaranteed to be an integer. If it can be None, consider a more explicit check or default value handling.

Comment on lines +481 to +484
if self.rp_world_size > 1:
attn_impl = getattr(model.config, '_attn_implementation', None)
if attn_impl != 'flash_attention_2':
raise NotImplementedError('Derived ring attention only supports flash_attention_2 backend.')
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check for flash_attention_2 backend should ideally be done using a constant or a centralized configuration check to avoid hardcoded strings.

@@ -0,0 +1,283 @@
import os
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import import os is unused in this file. Please remove it.

Comment on lines +115 to +123
@cache
def _get_default_args(func):
spec = inspect.getfullargspec(func)
defaults = spec.defaults if spec.defaults is not None else ()
padded_defaults = (None, ) * (len(spec.args) - len(defaults)) + defaults
args = dict(zip(spec.args, padded_defaults))
if 'softcap' in args:
args['softcap'] = 0.0
return args
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _get_default_args function uses inspect.getfullargspec which is deprecated in newer Python versions. Consider using inspect.signature instead.

Comment on lines +490 to +491
if self.sp_strategy is not None:
loss_inputs, loss_outputs = self.sp_strategy.gather_loss_tensors(inputs, outputs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for gathering loss tensors is duplicated or very similar to logic in other parts of the codebase. Consider centralizing this loss gathering logic to avoid drift.

@meichangsu1 meichangsu1 changed the title support cp ,fix qwen3.5 gdn sp support cp ,fix qwen3.5 gdn sp Apr 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant