Change bias initialization from 'embed' to 'heads' #371

Open
csgoogle wants to merge 1 commit into main from fixbiassharding

Conversation

Collaborator

@csgoogle csgoogle commented Apr 6, 2026

  • Fix incorrect logical partitioning axes for attention and feed-forward parameters in Flax/WAN/LTX2 modules.
  • Refactor flash-attention block-size selection into a helper and add unit tests.

doc: https://docs.google.com/document/d/1absFkpQAMM3YaYWxO_FYeqzDpypYeDbPsJRAV86nFQ0/edit?usp=sharing&resourcekey=0-FOzOmM0UdfU1LcDd_7epvw
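For context, here is a minimal sketch of how Flax-style logical axis names resolve to mesh axes. The rule names below are illustrative, not taken from this repo's config: the point is that a bias tagged with the wrong logical axis resolves to a replicated spec instead of a sharded one.

```python
# Illustrative logical-to-mesh axis rules (hypothetical names, not the
# repo's actual sharding config): "heads" maps to the tensor-parallel
# mesh axis, while "embed" is unmapped and therefore replicated.
RULES = {"batch": "data", "heads": "tensor", "embed": None}

def resolve(logical_axes):
    """Map a tuple of logical axis names to mesh axis names."""
    return tuple(RULES.get(axis) for axis in logical_axes)

# A projection bias tagged ("embed",) resolves to (None,): every
# tensor-parallel shard keeps a full replica of the bias.
print(resolve(("embed",)))   # (None,)

# Tagged ("heads",) it resolves to ("tensor",): each shard holds only
# its own slice, which is what this PR changes the initialization to.
print(resolve(("heads",)))   # ('tensor',)
```

With `tp > 1`, the replicated variant also forces an all-gather of the full parameter before use, which is the communication overhead the results below measure.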

Results

| Metric | main | fixbiassharding | Δ |
| --- | --- | --- | --- |
| Compile time | 1913.9 s | 1906.4 s | -7.5 s |
| Inference time | 1656.4 s | 1642.1 s | -14.3 s (-0.9%) |

Notes

  • No difference observed with tp=1 configs; the improvement only surfaces when tensor parallelism is active, since the axis fixes reduce parameter all-gather overhead in the MLP layers
  • The primary motivation for this change is correctness: incorrect sharding axes can cause OOMs or numerical issues under other parallelism configs
  • Larger gains are expected at tp=4 or tp=8, where parameter communication is a larger fraction of step time

Video Quality Comparison

| Branch | Video |
| --- | --- |
| main | main.mp4 |
| fixbiassharding | fixbiassharding.mp4 |

PSNR/SSIM (frame-by-frame, 81 frames):

| Metric | Mean | Min | Max |
| --- | --- | --- | --- |
| PSNR (dB) | 19.37 | 18.83 | 20.17 |
| SSIM | 0.7884 | 0.7654 | 0.8043 |

The low PSNR/SSIM reflects floating-point non-determinism from the different sharding layouts accumulating over 50 denoising steps (bfloat16 plus different collective patterns); the videos are visually identical.
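For reference, a self-contained sketch of the frame-by-frame PSNR computation behind a comparison like the one above (standard definition; the actual metric script used for this PR is not shown here):

```python
import numpy as np

def psnr(a: np.ndarray, b: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio between two frames, in dB."""
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    if mse == 0.0:
        return float("inf")  # identical frames
    return 10.0 * np.log10(max_val**2 / mse)

def psnr_stats(frames_a, frames_b, max_val: float = 255.0):
    """Mean/min/max PSNR over paired frames, as reported in the table."""
    vals = [psnr(fa, fb, max_val) for fa, fb in zip(frames_a, frames_b)]
    return float(np.mean(vals)), float(np.min(vals)), float(np.max(vals))

# Example: a uniform error of 10 gray levels gives MSE = 100.
a = np.zeros((8, 8), dtype=np.uint8)
b = np.full((8, 8), 10, dtype=np.uint8)
print(round(psnr(a, b), 2))  # 28.13
```

A mean near 19–20 dB, as in the table, corresponds to per-pixel differences far larger than one gray level, which is why the note above attributes it to accumulated bfloat16 non-determinism rather than a quality regression.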

Video and Xprof after fix:

https://console.cloud.google.com/storage/browser/sagarchapara/shardingfixes

@csgoogle csgoogle requested a review from entrpn as a code owner April 6, 2026 10:09
@csgoogle csgoogle force-pushed the fixbiassharding branch 2 times, most recently from d822acb to 15af39f Compare April 13, 2026 10:41
entrpn previously approved these changes Apr 14, 2026
@Perseus14
Collaborator

Could you add more details and results on the new commits? @csgoogle

@csgoogle csgoogle force-pushed the fixbiassharding branch 5 times, most recently from 9780b17 to 7a6ab88 Compare April 15, 2026 14:59
Comment thread src/maxdiffusion/models/attention_flax.py
@csgoogle
Collaborator Author

> Could you add more details and results on the new commits? @csgoogle

done

@github-actions

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @csgoogle, but I was unable to process your request. Please see the logs for more details.


@github-actions github-actions bot left a comment


## 📋 Review Summary

This pull request provides critical fixes for logical partitioning axes across several models (Flax, LTX2, Wan) and refactors flash-attention block-size selection. These changes matter for correctness, and for performance under tensor parallelism by reducing parameter communication overhead.

### 🔍 General Feedback

  • Bug Fix: The refactoring of flash-attention block size selection into _select_flash_block_sizes correctly addresses a latent bug where the wrong axis (head dimension instead of sequence length) was used for rank-3 inputs.
  • Performance: Correcting the bias sharding axes for QKV and output projections, as well as MLP layers, will optimize communication patterns in multi-device TPU setups.
  • Testing: The addition of unit tests for the new flash-attention helper functions ensures the robustness of the block size selection logic.
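A hedged sketch of the rank handling described above (the helper name `_select_flash_block_sizes` comes from the review; the candidate sizes and divisibility policy below are assumptions): for both rank-3 `(batch, seq, head_dim)` and rank-4 `(batch, heads, seq, head_dim)` inputs, the sequence length sits at axis `-2`, whereas reading the last axis on a rank-3 input picks up `head_dim` instead, which is the wrong-axis bug the refactor fixes.

```python
# Hypothetical reconstruction of the block-size selection logic; the
# candidate sizes and fallback policy are illustrative assumptions.
def select_flash_block_sizes(shape, candidates=(512, 256, 128)):
    """Pick the largest candidate block size that divides the sequence
    length. Works for rank-3 (batch, seq, head_dim) and rank-4
    (batch, heads, seq, head_dim) shapes: in both, seq is axis -2.
    Reading shape[-1] here would return head_dim on rank-3 inputs,
    i.e. the wrong-axis bug described in the review.
    """
    seq_len = shape[-2]
    for block in candidates:
        if seq_len % block == 0:
            return block
    return min(candidates)  # fallback when nothing divides evenly

print(select_flash_block_sizes((2, 1024, 64)))    # 512 (rank-3)
print(select_flash_block_sizes((2, 8, 384, 64)))  # 128 (rank-4)
```

The unit tests added in this PR would cover exactly this kind of case pair: a rank-3 and a rank-4 shape that must resolve block sizes from the sequence axis rather than the head dimension.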

Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/ltx2/attention_ltx2.py
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan.py
Comment thread src/maxdiffusion/models/attention_flax.py