Posts tagged JAX

Supercharging JAX with Triton Kernels on AMD GPUs

Ready to supercharge your deep learning applications on AMD GPUs? In this blog, we’ll show you how to develop a custom fused dropout activation kernel for matrices in Triton, call it seamlessly from JAX, and benchmark its performance with ROCm. Combining Triton kernels with JAX in this way gives you a path to further speed up your models.



Optimize GPT Training: Enabling Mixed Precision Training in JAX using ROCm on AMD GPUs

This blog builds on the nanoGPT model we discussed in A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs. Here we show you how to incorporate mixed precision training into that JAX-implemented nanoGPT model.



Using statistical methods to reliably compare algorithm performance in large generative AI models with JAX Profiler on AMD GPUs

This blog provides a comprehensive guide to measuring and comparing the performance of various algorithms in a JAX-implemented generative AI model. Using the JAX Profiler together with statistical analysis, it demonstrates how to reliably evaluate key steps and compare algorithm performance on AMD GPUs.



A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs

2 July, 2024 by Douglas Jia.



LLM distributed supervised fine-tuning with JAX

25 Jan, 2024 by Douglas Jia.
