Posts tagged JAX

Optimize GPT Training: Enabling Mixed Precision Training in JAX using ROCm on AMD GPUs

This blog builds on the nanoGPT model we discussed in A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs. Here we show you how to incorporate mixed precision training into that JAX-implemented nanoGPT model.


Using statistical methods to reliably compare algorithm performance in large generative AI models with JAX Profiler on AMD GPUs

This blog provides a comprehensive guide to measuring and comparing the performance of various algorithms in a JAX-implemented generative AI model. Using the JAX Profiler and statistical analysis, it demonstrates how to reliably evaluate key steps and compare algorithm performance on AMD GPUs.


A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs

2 July, 2024 by Douglas Jia.


LLM distributed supervised fine-tuning with JAX

25 Jan, 2024 by Douglas Jia.