
Ring attention integration and other optimizations#359

Open
eltsai wants to merge 8 commits into main from elisatsai_ring_attention

Conversation


@eltsai eltsai commented Mar 16, 2026

In this PR, we integrated tokamax ring attention kernels for WAN models. The main changes:

  1. Added the ring attention and splash attention kernels under src/maxdiffusion/kernels/splash_attention/ (see the doc on the modifications we made: Ring Attention Kernel Precision Issue), and modified attention_flax.py to support tokamax_ring.

  2. JITted and sharded the VAE: added a new config vae_spatial (default 1) that lets users decide how to shard the VAE.

  3. Xprof: modified the profiler code to actually honor profiler_steps (for example, profiler_steps=3) instead of profiling the entire generation.

@eltsai eltsai force-pushed the elisatsai_ring_attention branch 3 times, most recently from f9b9f72 to b193301 Compare March 30, 2026 16:30

eltsai commented Mar 30, 2026

This PR mainly adds ring attention kernels for WAN models to maxdiffusion.

Core Features

  1. Ring Attention Implementation
    • Added tokamax ring attention as an optional attention kernel for WAN 2.1 and WAN 2.2
    • Enable it by setting attention=tokamax_ring
  2. VAE Optimization
    • JIT compilation and sharding of the VAE (Variational Autoencoder)
    • Use vae_spatial to shard the VAE (defaults to the TPU count)
  3. RoPE sharding: shard rotary_emb to match the sharding of QKV
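The RoPE sharding change in item 3 can be sketched in a few lines of JAX. This is illustrative only: the mesh axis names and shapes here are assumptions, not maxdiffusion's actual mesh, and the real code lives in the WAN attention path.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-device stand-in for the real (dp, fsdp, context, tp) mesh;
# the axis name "context" here is illustrative.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("context",))

# Ring attention shards Q/K/V along the sequence axis.
qkv_spec = P(None, "context", None)  # (batch, seq, head_dim)

# Shard the rotary embeddings the same way, so applying RoPE to the
# already-sharded Q/K requires no resharding or all-gather.
rotary_emb = jnp.ones((1, 4096, 64), dtype=jnp.float32)
rotary_emb = jax.device_put(rotary_emb, NamedSharding(mesh, qkv_spec))
assert rotary_emb.sharding.spec == qkv_spec
```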

Other miscellaneous changes:

  1. Refactored the WAN VAE invocation paths so pipelines use a more consistent VAE call/config flow across runs.
  2. Cleaned up profiling and xprof behavior so profiler_steps is actually honored, instead of capturing the entire run.
  3. Removed incorrect warnings.
  4. Added the missing __init__.py for the splash_attention package.
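The profiler_steps cleanup in item 2 amounts to tracing only a window of denoising steps rather than the whole run. A minimal sketch, assuming a simple step loop (the function and argument names are illustrative, not the actual maxdiffusion API):

```python
import jax
import jax.numpy as jnp

def generate(latents, num_steps, profiler_steps=0, profile_dir="/tmp/xprof"):
    """Run a denoising loop, profiling only the first `profiler_steps`
    steps instead of the entire generation."""
    for step in range(num_steps):
        if profiler_steps and step == 0:
            jax.profiler.start_trace(profile_dir)
        latents = latents * 0.99  # stand-in for one denoising step
        if profiler_steps and step == profiler_steps - 1:
            jax.profiler.stop_trace()
    return latents

out = generate(jnp.ones((4,)), num_steps=3)  # profiler_steps=0: no tracing
```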


eltsai commented Apr 14, 2026

e2e Generation Time

Below are the e2e generation times for tokamax flash attention and ring attention (best results in bold):

| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time (s) |
| --- | --- | --- | --- | --- | --- |
| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 |
| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** |
| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 |
| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** |

| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time (s) |
| --- | --- | --- | --- | --- | --- |
| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 |
| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **136.5** |
| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** |
| v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 |
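The relative speedup of ring over flash attention can be computed directly from the numbers above (all figures are taken from the tables; a positive percentage means ring is faster):

```python
# e2e times (seconds) from the tables above: (flash, ring) per setting.
times = {
    ("v7x-8", "WAN 2.1"): (264.2, 252.4),
    ("v7x-8", "WAN 2.2"): (212.7, 201.7),
    ("v7x-16", "WAN 2.1"): (146.6, 136.5),
    ("v7x-16", "WAN 2.2"): (117.8, 137.5),
}

for (accel, model), (flash, ring) in times.items():
    speedup_pct = (flash - ring) / flash * 100
    print(f"{accel} {model}: ring vs flash {speedup_pct:+.1f}%")
```

This makes the one regression visible at a glance: ring is faster everywhere except WAN 2.2 on 16 TPUs, matching the "Best Configs" choice of Tokamax Flash for that case.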


eltsai commented Apr 14, 2026

Optimizations (WAN 2.1), e2e time in seconds:

| Attempts | v7x-8 | v7x-16 |
| --- | --- | --- |
| flash | 325 | 184.0 |
| tokamax_flash | 324.93 | 169.2 |
| ring | 336 | 173.3 |
| tokamax_ring | 307.4 | 150.2 |
| tokamax_ring (best config) | 299.1 | 142.3 |
| tokamax_flash + VAE (jit and sharded) | 292.7 | 149.3 |
| tokamax_flash + VAE (jit and sharded on 8 devices) + RoPE sharding | 264.2 | 137.4 |
| tokamax_ring + VAE (jit and sharded) | 286.7 | 138.1 |
| tokamax_ring + VAE (jit and sharded on 4 devices) + RoPE sharding | 252.4 | 136.5 |

For WAN 2.1:

  • Switching from native flash to tokamax_ring gives roughly an 8% (8 TPUs) to 29% (16 TPUs) speedup.
  • Adding the JITted VAE and VAE temporal sharding gives a further ~4% speedup.


eltsai commented Apr 14, 2026

Best Configs

Below are the configs I used to get the above statistics for 720p video:

| Model | TPU Count | Attention Type | Sharding | E2E Time (sec) |
| --- | --- | --- | --- | --- |
| WAN 2.1 | 8 | Tokamax Ring | dp2-fsdp1-context4-tp1 | 252.4 |
| WAN 2.1 | 16 | Tokamax Ring | dp2-fsdp1-context8-tp1 | 136.5 |
| WAN 2.2 | 8 | Tokamax Ring | dp2-fsdp1-context4-tp1 | 201.7 |
| WAN 2.2 | 16 | Tokamax Flash | dp2-fsdp1-context8-tp1 | 117.8 |

With tile sizes:

```json
{"block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048, "block_q_dkv": 2048, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048, "use_fused_bwd_kernel": true}
```
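The tile-size setting above is plain JSON. A minimal sketch of parsing it and sanity-checking the block sizes (how maxdiffusion actually consumes these keys is not shown here; the divisibility check reflects a common constraint of splash-style kernels, where the compute tile iterates inside the larger memory tile):

```python
import json

tile_config = json.loads(
    '{"block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048,'
    ' "block_q_dkv": 2048, "block_kv_dkv": 2048,'
    ' "block_kv_dkv_compute": 2048, "use_fused_bwd_kernel": true}'
)

# Compute tiles should evenly divide the memory tiles they run within.
assert tile_config["block_kv"] % tile_config["block_kv_compute"] == 0
assert tile_config["block_kv_dkv"] % tile_config["block_kv_dkv_compute"] == 0
assert tile_config["use_fused_bwd_kernel"] is True
```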


eltsai commented Apr 14, 2026

Example output videos for tokamax ring attention 8 TPUs:

  1. WAN 2.1 Google Drive Folder
  2. WAN 2.2 Google Drive Folder

@eltsai eltsai marked this pull request as ready for review April 14, 2026 22:53
@eltsai eltsai requested a review from entrpn as a code owner April 14, 2026 22:53
@eltsai eltsai marked this pull request as draft April 14, 2026 22:54
@eltsai eltsai force-pushed the elisatsai_ring_attention branch from fbc6551 to 7bfc64a Compare April 15, 2026 00:12

eltsai commented Apr 15, 2026

Pytests passed, so I am squashing all commits into one.


eltsai commented Apr 15, 2026

Scripts to reproduce the e2e generation times:

  1. wan21_8tpu_tring.sh
  2. wan21_16tpu_tring.sh
  3. wan22_8tpu_tring.sh
  4. wan22_16tpu_tring.sh (tflash is the fastest for WAN 2.2 on 16 TPUs)

@eltsai eltsai marked this pull request as ready for review April 15, 2026 00:22
@eltsai eltsai self-assigned this Apr 15, 2026

eltsai commented Apr 15, 2026

Resolved the comments; the main change is moving VAE sharding to the config files.

A new run shows e2e time similar to the above. WAN 2.1 8 TPU tring logs: gpaste.

```
==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):     173.3s
  Compile:               252.6s
  ────────────────────────────────────────
  Inference:             251.3s
==================================================
```


eltsai commented Apr 15, 2026

Found a perf diff caused by merging new commits: WAN 2.1 stays the same, but WAN 2.2 has a ~15 sec increase in e2e time.

It was caused by passing attention_mask=encoder_attention_mask by default in attention_flax.py. For T2V videos the mask is all 1s, and passing it incurs new padding overhead, so I added a conditional check to pass None for T2V models.
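The conditional described above can be sketched as follows. This is an illustrative stand-in, not the exact attention_flax.py change; the function name and the is_t2v flag are assumptions.

```python
import numpy as np

def effective_attention_mask(encoder_attention_mask, is_t2v):
    """Return the mask to pass to the attention kernel."""
    if is_t2v:
        # T2V prompts yield an all-ones mask; passing it only adds
        # padding overhead, so skip the mask path entirely.
        return None
    return encoder_attention_mask

mask = np.ones((1, 512), dtype=np.int32)
assert effective_attention_mask(mask, is_t2v=True) is None
assert effective_attention_mask(mask, is_t2v=False) is mask
```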

Now the inference time is back to previous ~200 sec:

```
==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):     135.1s
  Compile:               200.1s
  ────────────────────────────────────────
  Inference:             198.1s
==================================================
```
