Ring attention integration and other optimizations#359
This PR mainly adds ring attention kernels for WAN models to maxdiffusion.

Core Features:

Other misc changes:
e2e Generation Time

Below are the e2e generation times for tokamax flash attention and ring attention (best results are in bold):
Optimizations for WAN 2.1:
Best Configs

Below are the configs I used to get the above statistics for 720p video.

With tile size:
Example output videos for:
All pytest tests passed, so I am squashing all commits into one.
Scripts to reproduce the e2e generation time:
Resolved comments; the one main change is moving VAE sharding into the config files. A new run shows a similar e2e time to the above. WAN 2.1 8-TPU timing logs: gpaste.
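For reference, moving VAE sharding into the config files means it is driven by a plain config value. A hypothetical fragment is sketched below; only `vae_spatial` (default 1) and `profiler_steps` are named in this PR, so the comments and any surrounding structure are assumptions, not the actual config schema:

```yaml
# Illustrative config fragment (not the actual maxdiffusion schema).
vae_spatial: 1     # number of spatial shards for the VAE; 1 = unsharded
profiler_steps: 3  # profile only this many steps instead of the whole generation
```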
Found a perf diff caused by merging new commits: WAN 2.1 stays the same, while WAN 2.2 shows a ~15 sec increase in e2e time, caused by a newly added change. The inference time is now back to the previous ~200 sec.
In this PR, we integrated tokamax ring attention kernels for WAN models. The main changes are:

- Added the ring attention kernel and splash attention kernel under `src/maxdiffusion/kernels/splash_attention/`. Here is the doc for the modifications we made: Ring Attention Kernel Precision Issue.
- Modified `attention_flax.py` to support `tokamax_ring`.
- JITted VAE and sharded VAE: added a new config option `vae_spatial` (default 1) to let users decide how to shard the VAE.
- Xprof: modified the profiler code to actually use `profiler_steps` (for example `profiler_steps=3`) instead of profiling the entire generation.
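The tokamax ring attention kernel itself is a TPU kernel and is not reproduced here. As a rough, device-free illustration of the underlying idea — each query block consumes KV blocks one "ring step" at a time and merges partial results with the online-softmax trick — here is a minimal NumPy sketch. The function names and block splitting are illustrative, not the kernel's API:

```python
import numpy as np

def ring_attention_sketch(q, k, v, num_blocks):
    """Simulate ring attention on one host: process KV blocks sequentially,
    as if they were passed around a device ring, combining partial softmax
    results with the online (log-sum-exp) rescaling trick."""
    seq, d = q.shape
    q_blocks = np.split(q, num_blocks)
    k_blocks = np.split(k, num_blocks)
    v_blocks = np.split(v, num_blocks)
    out = []
    for qi in q_blocks:
        m = np.full((qi.shape[0], 1), -np.inf)  # running row-max
        l = np.zeros((qi.shape[0], 1))          # running softmax denominator
        acc = np.zeros_like(qi)                 # running weighted-value sum
        for kj, vj in zip(k_blocks, v_blocks):  # one "ring step" per KV block
            s = qi @ kj.T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            scale = np.exp(m - m_new)           # rescale old partial results
            p = np.exp(s - m_new)
            l = l * scale + p.sum(axis=-1, keepdims=True)
            acc = acc * scale + p @ vj
            m = m_new
        out.append(acc / l)
    return np.concatenate(out)

def full_attention(q, k, v):
    """Reference: ordinary softmax attention over the full sequence."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v
```

Because the blockwise accumulation is exact (not an approximation), the sketch matches full attention up to floating-point error, which is why the real kernel's precision issue linked above is about kernel-level numerics rather than the algorithm itself.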