Posts tagged JAX
Optimize GPT Training: Enabling Mixed Precision Training in JAX using ROCm on AMD GPUs
- 06 September 2024
This blog builds on the nanoGPT model we discussed in A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs. Here we will show you how to incorporate mixed precision training into that JAX-implemented nanoGPT model.
Using statistical methods to reliably compare algorithm performance in large generative AI models with JAX Profiler on AMD GPUs
- 22 July 2024
This blog provides a comprehensive guide to measuring and comparing the performance of various algorithms in a JAX-implemented generative AI model. Using the JAX Profiler and statistical analysis, it demonstrates how to reliably evaluate key steps and compare algorithm performance on AMD GPUs.
A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs
- 02 July 2024
2 July, 2024 by Douglas Jia.