
"castonly/casttranspose HIP kernel optimization in fp8#519

Open
alextmagro wants to merge 14 commits intodevfrom
ct_opt

Conversation

@alextmagro
Contributor

@alextmagro alextmagro commented Apr 4, 2026

Improvements to cast_transpose and cast for FP8 delayed scaling

Introduced ROCm-specific cast and cast+transpose kernels tuned for MI350 and MI300 GPUs

Results are in https://github.com/ROCm/frameworks-internal/issues/16001

Optimizations Applied

Cast+Transpose Kernel (rocm_cast_transpose.cuh)

  1. OVecT packed FP8 shared memory — smem stores CVec<fp8,8> instead of float, 4x smaller footprint, avoids bank conflicts with +1 padding
  2. Register transpose during load — accumulate transposed FP8 into local_t[j2][iter].val[i2] during the load phase, avoiding a separate transpose pass
  3. Non-temporal stores for output_c (rowwise) and output_t (transposed) — __builtin_nontemporal_store confirmed as global_store_* ... nt in assembly
  4. gfx950 FP8 packed intrinsics — __builtin_amdgcn_cvt_scalef32_pk_fp8_f32 with scale=1.0 and pre-multiply. Values are clamped to the FP8 range via fmed3f before the intrinsic to prevent Inf-to-NaN (the intrinsic's E8M0 scale param is a post-quantization scale, not a pre-multiply)
  5. word_select packing — two intrinsic calls with word_select=false/true pack 4 FP8 values into one uint32
  6. Two-launch row strategy — STORE=8 for bulk, then best-fit single launch for remainder (STORE=4 if rem%128==0, STORE=2 if rem%64==0, general kernel otherwise). Max 2 launches for any M value
  7. Single-launch for small tensors — shapes with M<512 attempt a single kernel launch with the best-fit STORE size, avoiding multi-launch overhead
  8. Column cascade — LOAD_SZ checks for single-launch alignment, then cascades to smaller LOAD sizes for remainder columns
  9. CVec standalone vector type — aligned vector struct with load(), store(), nt_store() methods. Shared across cast, cast+transpose, and MXFP8 kernels via rocm_device_utils.cuh
  10. BF16/FP16 LOAD capped at 8 — LOAD=16 for BF16 uses 211 VGPRs (2 waves/SIMD). Capping at LOAD=8 uses 125 VGPRs (4 waves/SIMD), doubling occupancy
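The two-launch row strategy in item 6 can be sketched as a small host-side planner. This is an illustrative sketch only: the struct, the function name, and the assumption that a STORE=8 tile covers 256 rows (inferred from the rem%128 → STORE=4 and rem%64 → STORE=2 pattern) are mine, not the PR's actual code.

```cpp
#include <cassert>

// Hypothetical sketch of the two-launch row strategy (item 6):
// bulk rows use STORE=8; the remainder picks the largest STORE size
// whose tile height divides it, else falls back to the general kernel.
struct LaunchPlan {
    int bulk_store;     // STORE size for the bulk launch (0 if none)
    int rem_store;      // STORE size for the remainder launch (0 if none)
    bool general_tail;  // remainder handled by the general kernel
};

LaunchPlan plan_rows(long M) {
    LaunchPlan p{0, 0, false};
    long rem = M % 256;             // assume a STORE=8 tile spans 256 rows
    if (M >= 256) p.bulk_store = 8;
    if (rem == 0) return p;         // a single bulk launch suffices
    if (rem % 128 == 0)     p.rem_store = 4;
    else if (rem % 64 == 0) p.rem_store = 2;
    else                    p.general_tail = true;
    return p;                       // never more than two launches total
}
```

Under these assumptions, any M resolves to at most two launches, matching the "Max 2 launches for any M value" claim.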

Cast-Only Kernel (rocm_cast.cuh)

  1. Dedicated 1D grid-stride kernel — flat 1D indexing over M*N elements. No tiling, no cascade, single kernel launch for any shape
  2. 256 threads/block, 16 elements/thread — 42 VGPRs (from assembly), 10 waves/SIMD max, 0 scratch, 32 bytes LDS (amax only)
  3. Direct FP8 packing into OVec — intrinsic results written directly into the output CVec via reinterpret_cast<uint32_t*>. No intermediate converted[] array, which preserves the NT store hint through the compiler
  4. Non-temporal stores — CVec::nt_store() confirmed as global_store_dwordx4 ... nt in assembly (required eliminating the intermediate array to prevent the compiler from dropping the NT hint)
  5. gfx950 FP8 packed intrinsics — same as cast+transpose, with fmed3f clamping
  6. Dynamic grid sizing — cu_count blocks for FP32 and small BF16 tensors; cu_count*2 for BF16 tensors >128M elements (crossover point determined empirically)
  7. Scalar tail for non-aligned element counts (rarely exercised — model dimensions are multiples of 16+)
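The launch math in items 2 and 6 amounts to a small helper. A hypothetical sketch follows: the 128M-element threshold, 256 threads, and 16 elements/thread come from the list above, but the helper name and the clamp-to-available-work step are illustrative assumptions.

```cpp
#include <cassert>

// Hypothetical sketch of the cast-only grid sizing (items 2 and 6):
// a flat 1D grid-stride launch sized by CU count, with an empirical
// 2x bump for BF16 tensors above 128M elements.
constexpr long kThreads = 256;        // threads per block
constexpr long kEltsPerThread = 16;   // elements processed per thread

long grid_blocks(long num_elts, bool is_bf16, long cu_count) {
    long blocks = (is_bf16 && num_elts > (128L << 20))  // >128M elements
                      ? cu_count * 2
                      : cu_count;
    // Never launch more blocks than one grid-stride pass needs.
    long per_block = kThreads * kEltsPerThread;
    long needed = (num_elts + per_block - 1) / per_block;
    return blocks < needed ? blocks : needed;
}
```

Because the kernel is grid-stride, oversubscribing beyond `needed` blocks would only add launch overhead, hence the clamp.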

@alextmagro alextmagro force-pushed the ct_opt branch 3 times, most recently from 5c5f6fe to 9435f7a Compare April 4, 2026 05:30
@alextmagro alextmagro marked this pull request as ready for review April 4, 2026 05:31
@alextmagro alextmagro added the ci-level 3 CI test level 3 label Apr 4, 2026
@alextmagro alextmagro requested a review from ipanfilo April 6, 2026 18:40
do_general_config = true;
// even if we enforce to use OPTIMIZED_HIPIFIED_CAST_TRANSPOSE, may fall back to general kernel configs from NVTE cost model
bool nvte_use_optimized_hipified_cast_transpose = false;
if (const char* env_p = std::getenv("NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE") ) {
Collaborator
Are you proposing to remove the previous optimized hipified cast transpose PR completely? If so, is it because your new flow will always be faster than the previous one?

Contributor Author

Not only is my new flow always faster than the previous one, it appears that the upstream flow is always faster than the optimized hipified cast for both MI300 and MI350s anyway.

For real model shapes:
MI350 results: 1.73 TiB/s BW with upstream path, vs 1.17 TiB/s for hip only path

MI300 results: 1.48 TiB/s BW with upstream path, vs 0.90 TiB/s for hip only path

For reference, my path obtains 2.2 TiB/s on MI300 on average.

Collaborator

Hmm, did you try the configs listed in our previous PR #89? For example, I didn't see [2048, 12288] in your benchmark table. I don't recall every detail of that PR, but I remember heyi's optimized flow being better than the baseline on MI300X and MI308X.

Collaborator

The configs in PR #89 were probably from MLPerf's previous submission. And https://amd.atlassian.net/wiki/spaces/~yewang12/pages/435389025/Cast+Transpose+Kernel+Optimization was from SemiAnalysis. Better to try those as well.

Contributor Author

Results for those configs have all been added to the issue.

template <int LOAD_SIZE, int STORE_SIZE, int WARPS_PER_TILE,
typename IType, typename OType>
__global__ void __launch_bounds__(ROCM_CT_WARP_SIZE * WARPS_PER_TILE)
rocm_cast_transpose_kernel(const IType *__restrict__ input,
Collaborator

Besides the gfx950-specific optimizations, how is this different from the current hiprtc kernel:

__global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
    const IType* __restrict__ const input, const CType* __restrict__ const noop,
    OType* __restrict__ const output_c, OType* __restrict__ const output_t,
    const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr,
    CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
  if (noop != nullptr && noop[0] == 1.0f) return;
  // Vectorized load/store sizes
  constexpr size_t nvec_in = load_size / sizeof(IType);
  constexpr size_t nvec_out = store_size / sizeof(OType);
  using IVec = Vec<IType, nvec_in>;
  using OVecC = Vec<OType, nvec_in>;
  using OVecT = Vec<OType, nvec_out>;
  // Thread indices
  // Note: Block is interpreted as a warp_size x num_warps grid
  constexpr size_t bdimx = THREADS_PER_WARP;
  constexpr size_t bdimy = warps_per_tile;
  const size_t tid = threadIdx.x;
  const size_t tidx = tid % bdimx;
  const size_t tidy = tid / bdimx;
  const size_t bid = blockIdx.x;
  // Input tensors are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
  constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out;
  constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in;
  // Position of tile within tensor
  const size_t num_tiles_m = num_rows / tile_dim_m;
  const size_t tile_id_m = bid % num_tiles_m;
  const size_t tile_id_n = bid / num_tiles_m;
  const size_t tile_row = tile_id_m * tile_dim_m;
  const size_t tile_col = tile_id_n * tile_dim_n;
  // Number of nvec_out x nvec_in subtiles for each thread to load/store
  constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile;
  // FP8 factors
  const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
  CType amax = 0;
  // Load input to registers and transpose
  // Note: Each thread loads num_iterations subtiles, computes amax,
  // casts type, and transposes in registers.
  OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
  for (size_t iter = 0; iter < num_iterations; ++iter) {
    const size_t i1 = tidy + iter * bdimy;
    const size_t j1 = tidx;
#pragma unroll
    for (size_t i2 = 0; i2 < nvec_out; ++i2) {
      const size_t row = tile_row + i1 * nvec_out + i2;
      const size_t col = tile_col + j1 * nvec_in;
      IVec local_input;
      OVecC local_output_c;
      local_input.load_from(&input[row * row_length + col]);
#pragma unroll
      for (size_t j2 = 0; j2 < nvec_in; ++j2) {
        const CType in = static_cast<CType>(local_input.data.elt[j2]);
        const OType out = OType(in * scale);
        __builtin_assume(amax >= 0);
        amax = fmaxf(fabsf(in), amax);
        local_output_c.data.elt[j2] = out;
        local_output_t[j2][iter].data.elt[i2] = out;
      }
      local_output_c.store_to(&output_c[row * row_length + col]);
    }
  }
  // Copy from registers to shared memory to global memory
  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
  for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidy + iter * bdimy;
      const size_t j1 = tidx;
      shared_output_t[j1][i1] = local_output_t[j2][iter];
    }
    __syncthreads();
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidx;
      const size_t j1 = tidy + iter * bdimy;
      const size_t row = tile_row + i1 * nvec_out;
      const size_t col = tile_col + j1 * nvec_in + j2;
      shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
    }
    __syncthreads();
  }
  // Reduce amax over block
  if (amax_ptr != nullptr) {
    amax = reduce_max<warps_per_tile>(amax, tidy);
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax_ptr, amax);
    }
  }
  // Update scale-inverse
  if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
    reciprocal<CType>(scale_inv_ptr, scale);
  }
}

Contributor Author

The optimizations made that don't exist in the rtc kernel are listed at the bottom of the Claude comment discussing results -- let me know if those make sense or not.

Collaborator

I didn't know of a good tool to show you the logically duplicated parts, so I took the following screenshots.
The left side is our current dev hiprtc cast_transpose kernel and the right side is your new rocm_cast_transpose.cuh.

Perhaps I missed your new optimization techniques? I believe those logically duplicated parts can be consolidated to save us future maintenance burden.


Contributor Author

While we share some code with the RTC kernel, we have a few static optimizations that would make unifying things hard, such as the gfx950 intrinsics and the type-based NT stores.

I think we discussed this, but I am not sure there is much benefit to doing things via RTC instead of templating, other than keeping build times slightly shorter.

Contributor Author

Regarding maintenance, I agree that keeping things separate is sometimes a bit of extra work. However, for performance-sensitive kernels, a pattern is emerging where upstream kernels perform poorly for us as-is in most cases, and I think the tradeoff is worth it in cases like these.

Collaborator

Oh, I'm not asking you to choose one implementation over the other, nor to make your impl into hiprtc. I'm asking whether you can consolidate those two similar/duplicated pieces of code, purely from a coding-style point of view. For example, put the common parts into a shared header and use macros or constexpr to add your specific things like the gfx950 intrinsics and NT stores.

@alextmagro alextmagro requested review from aris134 and wangye805 April 7, 2026 18:32
@wangye805
Collaborator

By the way, can you make the PR title accurate and concise, for example, "cast transpose hipified kernel optimization in fp8 delay-scaling"? That way it can be easily found by future readers and TE maintainers.

@alextmagro alextmagro requested review from ipanfilo and wangye805 April 9, 2026 19:45
@alextmagro alextmagro requested a review from wangye805 April 11, 2026 14:13
@alextmagro alextmagro changed the title from "Ct opt" to "castonly/casttranspose HIP kernel optimization in fp8" Apr 11, 2026
@alextmagro
Contributor Author

alextmagro commented Apr 11, 2026

By the way, can you make the PR title accurate and concise, for example, "cast transpose hipified kernel optimization in fp8 delay-scaling"? That way it can be easily found by future readers and TE maintainers.

I have changed the title -- one note: this change also touches FP8 current scaling. Hoping to run some e2e benchmarks with and without this PR to see the overall effect on both options; will sync with Sudharshan for that.

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output.dtype(), OutputType,
Collaborator

It is in FP8 space, but neither the input nor the output datatype switch is FP8.

Contributor Author

This cast is for FP8 delayed scaling only; we can change the output switch to FP8. This also aligns us with VectorizedUnaryKernel, which this function replaces.

@alextmagro alextmagro requested a review from ipanfilo April 14, 2026 17:16
Labels: ci-level 3 CI test level 3

4 participants