Add Google Cloud ML Diagnostics and profiling support #377

Open
mbohlool wants to merge 1 commit into main from mldiag

@mbohlool
Collaborator

Description

This PR integrates Google Cloud ML Diagnostics and XProf profiling into MaxDiffusion. It lets users automatically profile and track performance metrics of their training and generation runs via the Cluster Director console. The ML Diagnostics package is an optional dependency, so it does not bloat the environment for users who do not need it.

Note: This PR covers only profiling. A follow-up PR will handle metrics.

Changes

  • Documentation: Added docs/profiling.md, which explains manual installation, configuration, and troubleshooting for ML Diagnostics.
  • Configurations: Added the enable_ml_diagnostics, profiler_gcs_path, and enable_ondemand_xprof flags to all base configuration files (base*.yml, ltx*.yml, etc.).
  • Utils (max_utils.py):
    • Introduced a unified Profiler class (usable both as an object and a context manager) to abstract away the standard JAX profiler and the new ML Diagnostics xprof profiler.
    • Added ensure_machinelearning_job_runs to initialize the diagnostics job. Initialization is skipped when the optional google-cloud-mldiagnostics library is not installed.
  • Trainers & Generation scripts:
    • Added initialization calls for machine learning runs at the start of training and generation scripts.
    • Updated BaseWanTrainer, DreamboothTrainer, FluxTrainer, StableDiffusionXLTrainer, and StableDiffusionTrainer to use the unified Profiler class.
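
For reviewers who want a picture of the "object and context manager" shape described under Utils, here is a minimal sketch. It is not the code in max_utils.py: the constructor arguments (`start_fn`, `stop_fn`, `enabled`) are assumptions made for illustration, and the backend is a stand-in so the example runs without JAX. In a real setup the two callables might wrap `jax.profiler.start_trace(log_dir)` and `jax.profiler.stop_trace()`.

```python
from typing import Callable


class Profiler:
  """Hypothetical sketch: one wrapper, usable explicitly or via `with`."""

  def __init__(
      self,
      start_fn: Callable[[], None],
      stop_fn: Callable[[], None],
      enabled: bool = True,
  ):
    # start_fn / stop_fn are assumed hooks for whichever backend is active.
    self._start_fn = start_fn
    self._stop_fn = stop_fn
    self._enabled = enabled
    self._active = False

  def start(self) -> None:
    # No-op when profiling is disabled or a trace is already running.
    if self._enabled and not self._active:
      self._start_fn()
      self._active = True

  def stop(self) -> None:
    # Only stop a trace that this object actually started.
    if self._active:
      self._stop_fn()
      self._active = False

  def __enter__(self) -> "Profiler":
    self.start()
    return self

  def __exit__(self, exc_type, exc, tb) -> bool:
    self.stop()
    return False  # do not swallow exceptions raised in the profiled block


# Stand-in backend that records calls instead of tracing.
events = []
profiler = Profiler(
    start_fn=lambda: events.append("start"),
    stop_fn=lambda: events.append("stop"),
)

# Object style, e.g. start at one training step and stop at a later one.
profiler.start()
profiler.stop()

# Context-manager style, e.g. around a single generation call.
with profiler:
  events.append("step")
```

Keeping the start/stop bookkeeping in one place is what lets each trainer drop its own activate/deactivate pair; the guard on `_active` makes a stray second `stop()` harmless.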

Testing

  • (Author to fill in: Describe any tests performed to validate these changes, e.g., ran a short training step on a TPU pod to verify traces successfully upload to the Cloud bucket)

@mbohlool mbohlool requested a review from entrpn as a code owner April 14, 2026 18:38
@github-actions

This commit introduces support for automated profiling and performance tracking
using Google Cloud ML Diagnostics. It enables on-demand xprof and integrates the
diagnostics package into both the generation and training scripts.

Changes:
- Added `docs/profiling.md` to guide users on enabling ML diagnostics.
- Added default ML Diagnostics config settings to base config files.
- Created `Profiler` class in `max_utils.py` to abstract JAX profiling and ML Diagnostics.
- Replaced `activate_profiler` / `deactivate_profiler` in trainers with the unified `Profiler` class.
- Added `ensure_machinelearning_job_runs` initialization to execution scripts.
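
The optional-dependency behaviour behind `ensure_machinelearning_job_runs` can be illustrated with a small guard. This is a sketch under stated assumptions, not the function from this commit: the helper names, the dict-style config, and the import name `google_cloud_mldiagnostics` are all hypothetical, and the actual initialization call is left out.

```python
import importlib.util


def optional_module_available(module_name: str) -> bool:
  """Return True if `module_name` can be imported, without importing it."""
  try:
    return importlib.util.find_spec(module_name) is not None
  except (ImportError, ValueError):
    # Raised for a missing parent package or a module with a broken spec.
    return False


def ensure_job_run_sketch(
    config: dict,
    module_name: str = "google_cloud_mldiagnostics",  # assumed import name
) -> bool:
  """Hypothetical guard: report whether diagnostics would be initialized."""
  if not config.get("enable_ml_diagnostics", False):
    return False  # feature is switched off in the config
  if not optional_module_available(module_name):
    print(f"{module_name} is not installed; skipping ML Diagnostics.")
    return False
  # The real initialization call would go here (omitted in this sketch).
  return True
```

With a guard of this shape, a script can call the function unconditionally at startup and keep running with the standard profiler when the package is absent.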