
Add Ulysses attention #376

Open
csgoogle wants to merge 4 commits into main from ulysses-attention-benchmark

Conversation

csgoogle (Collaborator) commented Apr 13, 2026

Summary

This PR adds Ulysses attention support for WAN TPU inference in MaxDiffusion and documents how to enable it.

Design Doc: https://docs.google.com/document/d/1_hrPGaIwj84iF8vFJrcdKdmwfKJPvW6O2Sy5ftLVn60/edit?usp=sharing&resourcekey=0-p0zkvHa_NJDwHPqLwNxNCg

What Changed

  • Added a TPU Ulysses attention path for WAN that performs a sequence-to-head all_to_all before local splash attention and restores the original layout afterward.
  • Refactored the TPU flash/Ulysses block-size resolution logic so both paths share the same helper.
  • Added a fail-fast ValueError raised when the attention head count is not divisible by the context shard count.
  • Added tests.
  • Updated the README to document Ulysses support for WAN inference, including the required `attention="ulysses"` and `ici_context_parallelism > 1` override pattern.
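The sequence-to-head exchange described above can be sketched with `jax.lax.all_to_all`. This is a hedged illustration, not the PR's actual code: the shard count, head count, and dimensions below are made up for the demo, and an identity function stands in for the local splash attention. The divisibility check mirrors the fail-fast ValueError added in this PR.

```python
import os
# Force 4 virtual CPU devices so the sketch runs anywhere (assumed cp=4).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

CP = 4      # context-parallel shard count (ici_context_parallelism)
HEADS = 8   # attention head count; must be divisible by CP
SEQ = 16    # full sequence length
DIM = 4     # head dimension

# Fail fast, mirroring the check added in this PR.
if HEADS % CP != 0:
    raise ValueError(
        f"head count {HEADS} is not divisible by context shards {CP}")

def ulysses_attention(x):
    # x per shard: [batch, seq/CP, heads, dim]
    # all_to_all: gather the full sequence, scatter heads across shards
    # -> [batch, seq, heads/CP, dim]
    x = jax.lax.all_to_all(x, "cp", split_axis=2, concat_axis=1, tiled=True)
    # Stand-in for the local splash attention over the full sequence;
    # identity keeps the demo's round trip checkable.
    out = x
    # Inverse all_to_all restores the original sequence-sharded layout.
    out = jax.lax.all_to_all(out, "cp", split_axis=1, concat_axis=2, tiled=True)
    return out

# Leading axis = devices; per-device shape [1, SEQ//CP, HEADS, DIM].
x = jnp.arange(CP * (SEQ // CP) * HEADS * DIM, dtype=jnp.float32).reshape(
    CP, 1, SEQ // CP, HEADS, DIM)
y = jax.pmap(ulysses_attention, axis_name="cp")(x)
print(bool(jnp.allclose(x, y)))  # the layout round trip is lossless
```

With a real attention kernel in place of the identity, each shard attends over the full sequence but only over its `HEADS // CP` local heads, which is what makes the head-divisibility requirement necessary.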

Performance

TPU v6e

Wan2.2 I2V

Setup:

  • model: Wan-AI/Wan2.2-I2V-A14B-Diffusers
  • hardware: 8x TPU v6e
  • parallelism: dp=2, cp=4, fsdp=1, tp=1
  • timing config: 40 inference steps, 81 frames, 720x1280
| Global Batch Size | Flash | Ulysses | Delta |
|---|---|---|---|
| 1 | 285.56s | 251.45s | -11.9% |
| 2 | 533.67s | 491.22s | -8.0% |

Wan2.2 T2V

Setup:

  • model: Wan-AI/Wan2.2-T2V-A14B-Diffusers
  • hardware: 8x TPU v6e
  • parallelism: dp=2, cp=4, fsdp=1, tp=1
  • timing config: 40 inference steps, 81 frames, 720x1280
| Global Batch Size | Flash | Ulysses | Delta |
|---|---|---|---|
| 1 | 275.54s | 246.90s | -10.39% |
| 2 | 535.40s | 480.24s | -10.30% |

TPU v7x

Wan2.2 I2V

Setup:

  • model: Wan-AI/Wan2.2-I2V-A14B-Diffusers
  • hardware: TPU v7-8 (8 chips)
  • parallelism: ici_context_parallelism=4, ici_data_parallelism=2
  • timing config: 40 inference steps, 81 frames, 720x1280
  • flash block sizes: block_q=2048, block_kv=2048, block_kv_compute=1024
| Global Batch Size | Flash | Ulysses | Delta |
|---|---|---|---|
| 1 | 209s | 199s | -5% |
| 2 | 414s | 394s | -5% |
| 4 | 829s | 780s | -6% |
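The Delta columns above can be reproduced from the raw timings as the relative change of Ulysses versus Flash wall time (the v7x table rounds to whole percents):

```python
# Sanity-check the Delta columns: (ulysses - flash) / flash, as a percentage.
# Timings in seconds, copied from the tables above.
runs = [
    ("v6e I2V, bs=1", 285.56, 251.45),
    ("v6e I2V, bs=2", 533.67, 491.22),
    ("v6e T2V, bs=1", 275.54, 246.90),
    ("v6e T2V, bs=2", 535.40, 480.24),
    ("v7x I2V, bs=1", 209.0, 199.0),
    ("v7x I2V, bs=2", 414.0, 394.0),
    ("v7x I2V, bs=4", 829.0, 780.0),
]
for name, flash, ulysses in runs:
    delta = (ulysses - flash) / flash * 100.0
    print(f"{name}: {delta:+.1f}%")
```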


@csgoogle csgoogle changed the title working code Add Ulysses attention Apr 15, 2026
@csgoogle csgoogle marked this pull request as ready for review April 15, 2026 09:18
@csgoogle csgoogle requested a review from entrpn as a code owner April 15, 2026 09:18
Comment thread: src/maxdiffusion/models/attention_flax.py