From 6b72f952fe7710fd16dc36372f5823e4bc16831d Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 15 Apr 2026 17:19:55 +0530 Subject: [PATCH] upsampler fix --- src/maxdiffusion/models/ltx2/attention_ltx2.py | 8 ++++---- src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py index 7441a2038..398b0f473 100644 --- a/src/maxdiffusion/models/ltx2/attention_ltx2.py +++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py @@ -359,13 +359,13 @@ def __init__( # 1. Define Partitioned Initializers (Logical Axes) # Q, K, V kernels: [in_features (embed), out_features (heads)] qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")) - # Q, K, V biases: [out_features (embed)] - qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",)) + # Q, K, V biases: [out_features (heads)] + qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)) # Out kernel: [in_features (heads), out_features (embed)] out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")) - # Out bias: [out_features (heads)] - out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)) + # Out bias: [out_features (embed)] + out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",)) # Norm scales norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)) diff --git a/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py b/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py index 1b43457af..20436f42f 100644 --- a/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py +++ b/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py @@ -165,12 +165,12 @@ def __init__(self, in_channels: int, mid_channels: int = 1024, scale: float = 2. in_channels, (num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), rngs=rngs ) self.pixel_shuffle = PixelShuffleND(dims=2, upscale_factors=(num, num)) - self.blur = BlurDownsample(dims=2, stride=den) + self.blur_down = BlurDownsample(dims=2, stride=den) def __call__(self, x: jax.Array) -> jax.Array: x = self.conv(x) x = self.pixel_shuffle(x) - x = self.blur(x) + x = self.blur_down(x) return x