#519: castonly/casttranspose HIP kernel optimization in fp8
alextmagro wants to merge 14 commits into dev
Conversation
Force-pushed from 5c5f6fe to 9435f7a
do_general_config = true;
// even if we enforce to use OPTIMIZED_HIPIFIED_CAST_TRANSPOSE, may fall back to general kernel configs from NVTE cost model
bool nvte_use_optimized_hipified_cast_transpose = false;
if (const char* env_p = std::getenv("NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE")) {
Are you proposing to remove the previous optimized hipified cast transpose PR completely? If so, is it because your new flow will always be faster than the previous one?
Not only is my new flow always faster than the previous one, but the upstream flow also appears to be faster than the optimized hipified cast transpose on both MI300 and MI350 anyway.
For real model shapes:
MI350 results: 1.73 TiB/s BW with upstream path, vs 1.17 TiB/s for hip only path
MI300 results: 1.48 TiB/s BW with upstream path, vs 0.90 TiB/s for hip only path
For reference, my path obtains 2.2 TiB/s on MI300 on average.
Hmm, did you try the configs listed in our previous PR #89? For example, I didn't see [2048, 12288] in your benchmark table. I don't recall every detail of that PR, but I remember heyi's optimized flow being better than the baseline on MI300X and MI308X.
The configs in PR #89 were probably from MLPerf's previous submission, and https://amd.atlassian.net/wiki/spaces/~yewang12/pages/435389025/Cast+Transpose+Kernel+Optimization was from SemiAnalysis. It would be better to try those as well.
Results for those configs have all been added to the issue.
template <int LOAD_SIZE, int STORE_SIZE, int WARPS_PER_TILE,
          typename IType, typename OType>
__global__ void __launch_bounds__(ROCM_CT_WARP_SIZE * WARPS_PER_TILE)
rocm_cast_transpose_kernel(const IType *__restrict__ input,
Besides the gfx950-specific optimization, how is this different from the current hiprtc kernel?
The optimizations that don't exist in the rtc kernel are listed at the bottom of the Claude comment discussing results -- let me know whether those make sense.
I don't know of a good tool to show you the logically duplicated parts, so I took the following screenshots.
The left side is our current dev hiprtc cast_transpose kernel and the right side is your new rocm_cast_transpose.cuh.
Maybe I missed your new optimization techniques? I believe those logically duplicated parts can be consolidated to reduce our future maintenance burden.
While we share some code with the rtc kernel, we have a few static optimizations that would make unifying things hard, like gfx950 intrinsics and type-based NT stores.
I think we discussed this, but I am not sure there is much benefit to doing things via rtc instead of templating, other than keeping build times slightly shorter.
Regarding maintenance, I agree that it is sometimes extra work to keep things separate. However, for performance-sensitive kernels, a pattern is emerging where upstream kernels as-is perform poorly for us in most cases, and I think the tradeoff is worth it in cases like these.
Oh, I'm not asking you to choose one implementation over the other, nor to turn your impl into hiprtc. I'm asking whether you can consolidate those two similar/duplicated pieces of code, from a pure coding-style point of view: for example, put the common parts into a shared header, and use macros or constexpr to add your specific things like gfx950 intrinsics and NT stores.
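The suggested consolidation could look roughly like the sketch below. This is a hypothetical host-side illustration, not the actual TE code: `store_element`, `cast_span`, and the `UseNontemporal` parameter are made-up names showing how shared loop logic can live in one templated helper while target-specific behavior (gfx950 intrinsics, NT stores) is selected with `if constexpr`.

```cpp
#include <cassert>

// Hypothetical sketch of the consolidation idea: one shared templated helper,
// with target-specific stores chosen at compile time via if constexpr.
template <typename OType, bool UseNontemporal>
inline void store_element(OType *dst, OType v) {
  if constexpr (UseNontemporal) {
    // On device this branch would use __builtin_nontemporal_store(v, dst);
    // a plain store is used here so the sketch runs on the host.
    *dst = v;
  } else {
    *dst = v;
  }
}

// Common cast loop that both the hiprtc variant and the static variant
// could share from a single header.
template <typename IType, typename OType, bool UseNontemporal>
void cast_span(const IType *in, OType *out, int n, float scale) {
  for (int i = 0; i < n; ++i) {
    store_element<OType, UseNontemporal>(out + i,
                                         static_cast<OType>(in[i] * scale));
  }
}
```

The point is that only the innermost store differs per target, so the duplicated tile/loop logic collapses into one place.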
By the way, can you make the PR title accurate and concise, for example, "cast transpose hipified kernel optimization in fp8 delay-scaling"? That way it can be easily found by future readers and TE maintainers.
I have changed the title -- one difference: we also touch FP8 current scaling with this change. I am hoping to run some e2e benchmarks with and without this PR to see the overall effect on both options. Will sync with Sudharshan for that.
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
    input.dtype(), InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
        output.dtype(), OutputType,
It is in FP8 space, but neither the Input nor the Output datatype switch is FP8.
This cast is for FP8 delayed scaling only, so we can change the output switch to FP8. This also aligns us with VectorizedUnaryKernel, which this function replaces.
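Narrowing the output dispatch could look like the sketch below. This is illustrative only: `DType`, `fp8_only_output_switch`, and the enumerators are hypothetical stand-ins, not the actual TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT macro; the idea is that a non-FP8 output dtype on this path becomes an immediate error rather than a silently instantiated case.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical dtype enum standing in for the TE dtype set.
enum class DType { kFloat32, kBFloat16, kFloat8E4M3, kFloat8E5M2 };

// Dispatch that only accepts the two FP8 formats, mirroring the suggestion
// to restrict the output switch when the path is FP8 delayed scaling only.
template <typename F>
auto fp8_only_output_switch(DType dtype, F &&f) {
  switch (dtype) {
    case DType::kFloat8E4M3:
      return f("e4m3");  // would instantiate the kernel for fp8 e4m3
    case DType::kFloat8E5M2:
      return f("e5m2");  // would instantiate the kernel for fp8 e5m2
    default:
      // Any non-FP8 output dtype is a programming error on this path.
      throw std::invalid_argument("output dtype must be FP8");
  }
}
```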
Improvements to cast_transpose and cast for FP8 delayed scaling
Introduced ROCm-specific cast and cast+transpose functions tuned for MI350s and MI300s.
Results are in https://github.com/ROCm/frameworks-internal/issues/16001
Optimizations Applied
Cast+Transpose Kernel (`rocm_cast_transpose.cuh`)

- Transpose performed during the load phase by writing into `local_t[j2][iter].val[i2]`, avoiding a separate transpose pass
- `__builtin_nontemporal_store`, confirmed as `global_store_* ... nt` in assembly
- `__builtin_amdgcn_cvt_scalef32_pk_fp8_f32` with scale=1.0 and pre-multiply. Values clamped to FP8 range via `fmed3f` before the intrinsic to prevent Inf-to-NaN (the intrinsic's E8M0 scale param is a post-quantization scale, not a pre-multiply)
- `word_select=false/true` packs 4 FP8 values into one uint32
- `load()`, `store()`, `nt_store()` methods, shared across cast, cast+transpose, and MXFP8 kernels via `rocm_device_utils.cuh`

Cast-Only Kernel (`rocm_cast.cuh`)

- Processes `M*N` elements flat: no tiling, no cascade, a single kernel launch for any shape
- Stores through `reinterpret_cast<uint32_t*>` with no intermediate `converted[]` array, which preserves the NT store hint through the compiler
- `CVec::nt_store()` confirmed as `global_store_dwordx4 ... nt` in assembly (required eliminating the intermediate array to prevent the compiler from dropping the NT hint)
- `fmed3f` clamping
- `cu_count` blocks for FP32 and small BF16 tensors; `cu_count*2` for BF16 tensors >128M elements (crossover point determined empirically)
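Two of the items above can be sketched on the host. This is an illustrative approximation, not the kernel code: the constant and helper names (`kFp8E4M3Max`, `clamp_to_fp8_range`, `pack4_fp8`) are assumptions, and on device the clamp would be a single `v_med3_f32` (`fmed3f`) with the conversion handled by `__builtin_amdgcn_cvt_scalef32_pk_fp8_f32`.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

constexpr float kFp8E4M3Max = 448.0f;  // largest finite FP8 E4M3 value

// Clamp before conversion: a value that overflows the FP8 range after the
// pre-multiply would otherwise come out of the convert intrinsic as NaN,
// since the intrinsic's E8M0 scale is applied after quantization, not before.
inline float clamp_to_fp8_range(float x) {
  // Equivalent to med3(x, -max, +max) on device.
  return std::fmin(std::fmax(x, -kFp8E4M3Max), kFp8E4M3Max);
}

// Pack four already-converted FP8 bytes into one uint32 (what the
// word_select packing achieves on device), so a single 32-bit store
// writes four FP8 values at once.
inline uint32_t pack4_fp8(uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
  return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
         (uint32_t(d) << 24);
}
```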