Posts tagged JAX

Running SOTA AI-based Weather Forecasting models on AMD Instinct

Weather forecasting is a complex scientific problem where immense progress has been made through the Numerical Weather Prediction (NWP) approach, which uses computational fluid dynamics-based models. Forecasting is usually done in three stages. First, a data assimilation stage uses all data streams available at time \(t\) (sometimes earlier times as well, to improve the estimate) to estimate the current 3D state of the atmosphere and surface \(S_{t}\), parameterized by a number of variables at time \(t\). Second, a forecasting stage predicts the state \(\hat{S}_{t + \delta t}\), i.e., all the variables, at a later time \(t + \delta t\). Third, a downstream stage uses the forecasted state at time \(t + \delta t\) to estimate weather variables at more specific times and locations.
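
As a rough illustration of the forecasting stage (a minimal sketch, not code from the blog post), ML-based forecast models are typically applied autoregressively: the predicted state \(\hat{S}_{t+\delta t}\) is fed back in to produce \(\hat{S}_{t+2\delta t}\), and so on. The hypothetical `forecast_step` below stands in for any learned forecast model:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a learned surrogate mapping S_t to S_{t+dt};
# here it is just a toy linear map on a 2D grid of state variables.
def forecast_step(state):
    return 0.99 * state + 0.01 * jnp.roll(state, shift=1, axis=0)

def rollout(state_t, n_steps):
    # Apply the forecast model autoregressively: each predicted state is
    # fed back in as the input for the next step.
    def body(state, _):
        next_state = forecast_step(state)
        return next_state, next_state
    _, trajectory = jax.lax.scan(body, state_t, xs=None, length=n_steps)
    return trajectory

initial_state = jnp.ones((64, 32))  # toy assimilated state S_t
states = rollout(initial_state, n_steps=10)
print(states.shape)                 # (10, 64, 32): one state per lead time
```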

Read more ...


ROCm 7.0: An AI-Ready Powerhouse for Performance, Efficiency, and Productivity

Artificial intelligence now defines the performance envelope for modern computation. In this blog we introduce ROCm 7.0, an AI-centric release designed to help our community benefit directly from this paradigm shift. ROCm 7.0 delivers a platform purpose-built for the era of generative AI, large-scale inference and training, and accelerated discovery, helping you boost the performance, efficiency, and scalability of your workloads.

Read more ...


Supercharging JAX with Triton Kernels on AMD GPUs

Ready to supercharge your deep learning applications on AMD GPUs? In this blog, we’ll show you how to develop a custom fused dropout-activation kernel for matrices in Triton, seamlessly call it from JAX, and benchmark its performance with ROCm. This powerful combination will take your model’s performance to the next level.
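
To give a flavor of the approach, here is a minimal sketch, assuming the jax-triton bindings (`jax_triton.triton_call`) and a simple ReLU as the activation; it is not the blog's exact kernel. The kernel fuses the activation and inverted dropout into a single pass over the data:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def dropout_relu_kernel(x_ptr, out_ptr,
                        n_elements: tl.constexpr, p: tl.constexpr,
                        seed: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    x = tl.maximum(x, 0.0)           # activation (ReLU)
    random = tl.rand(seed, offsets)  # per-element uniform randoms
    # Inverted dropout: scale kept values so the expected value is unchanged.
    out = tl.where(random > p, x / (1.0 - p), 0.0)
    tl.store(out_ptr + offsets, out, mask=mask)

def dropout_relu(x, p=0.1, seed=0, block_size=1024):
    return jt.triton_call(
        x,
        kernel=dropout_relu_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(triton.cdiv(x.size, block_size),),
        n_elements=x.size, p=p, seed=seed, BLOCK_SIZE=block_size)

x = jnp.ones((4096,), dtype=jnp.float32)
y = jax.jit(dropout_relu)(x)  # composes with jit like any other JAX op
```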

Read more ...


Optimize GPT Training: Enabling Mixed Precision Training in JAX using ROCm on AMD GPUs

This blog builds on the nanoGPT model we discussed in A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs. Here we show you how to incorporate mixed-precision training into that JAX-implemented nanoGPT model.
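
As a minimal sketch of the core idea (using a toy linear model, not the blog's nanoGPT code), one common pattern keeps fp32 master parameters, casts them to bfloat16 for the forward and backward passes, and applies updates in fp32; `jax.grad` returns fp32 gradients because they flow back through the cast:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def mixed_precision_loss(params_fp32, x, y):
    # Cast parameters and inputs down for compute; the fp32 leaves stay
    # the "master" copy that the optimizer updates.
    params_bf16 = jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16), params_fp32)
    return loss_fn(params_bf16, x.astype(jnp.bfloat16), y.astype(jnp.bfloat16))

@jax.jit
def train_step(params, x, y, lr=1e-3):
    grads = jax.grad(mixed_precision_loss)(params, x, y)  # fp32 gradients
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (16, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 16)), jnp.ones((8, 1))
params = train_step(params, x, y)
```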

Read more ...


Using statistical methods to reliably compare algorithm performance in large generative AI models with JAX Profiler on AMD GPUs

This blog provides a comprehensive guide to measuring and comparing the performance of various algorithms in a JAX-implemented generative AI model. Leveraging the JAX Profiler and statistical analysis, it demonstrates how to reliably evaluate key steps and compare algorithm performance on AMD GPUs.
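
As a small sketch of the measurement side (with made-up matmul variants rather than the blog's model), the key points are to exclude compilation via warm-up runs, block on JAX's asynchronous dispatch before stopping the clock, and collect enough samples to report a mean with its standard error; a JAX Profiler trace can be captured around the same calls:

```python
import time
import statistics
import jax
import jax.numpy as jnp

def time_fn(fn, *args, n_warmup=3, n_runs=30):
    for _ in range(n_warmup):
        fn(*args).block_until_ready()  # warm-up excludes compile time
    samples = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        fn(*args).block_until_ready()  # wait for async dispatch to finish
        samples.append(time.perf_counter() - t0)
    return samples

x = jnp.ones((4096, 4096))
variant_a = jax.jit(lambda m: m @ m)
variant_b = jax.jit(lambda m: jnp.einsum("ij,jk->ik", m, m))

for name, fn in [("matmul", variant_a), ("einsum", variant_b)]:
    s = time_fn(fn, x)
    mean, sem = statistics.mean(s), statistics.stdev(s) / len(s) ** 0.5
    print(f"{name}: {mean * 1e3:.3f} ms +/- {sem * 1e3:.3f} ms (SEM)")

# A profiler trace (viewable in TensorBoard/Perfetto) can wrap the same calls:
with jax.profiler.trace("/tmp/jax-trace"):
    variant_a(x).block_until_ready()
```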

Read more ...


A Guide to Implementing and Training Generative Pre-trained Transformers (GPT) in JAX on AMD GPUs

2 July, 2024

Read more ...


LLM distributed supervised fine-tuning with JAX

25 Jan, 2024

Read more ...