Llama3 405B Tuning Experience on AMD MI300x

Introduction

Open-source models are becoming increasingly large, so the need for reliable infrastructure for large-scale AI training is greater than ever. Our company recently fine-tuned LLaMA 3.1 405B on AMD GPUs, demonstrating their ability to handle large-scale AI tasks efficiently. Our experience was extremely positive, and we are happy to publish all of our work as open source on GitHub.

AMD GPUs, and especially the MI300X series, are a serious alternative to NVIDIA AI hardware, delivering more performance per dollar. Our system consisted of a single node with 8 AMD MI300x GPUs, and for fine-tuning we used JAX. In this article we tell the whole story of fine-tuning LLaMA 405B, including the details of parameter sharding and our LoRA implementation.

What is JAX and why did we choose it?

JAX is a powerful machine learning library that combines NumPy-like APIs, automatic differentiation, and Google's XLA compiler. It has great APIs for model parallelism, ideal for training huge models like LLaMA 3.1 405B.

Why I love JAX so much:

  1. Pure functions: JAX encourages you to write pure functions (if you want to JIT-compile your code), which makes your code easier to build, debug, and read (see the small sketch after this list).

  2. Advanced Parallelism: The JAX library's flexible JIT APIs natively support advanced data and model parallelism, which is critical for large-scale training.

  3. Improving the cleanliness of code bases: The JAX design philosophy encourages writing code that is natively portable across hardware platforms (CPU, GPU, TPU), resulting in cleaner and more maintainable codebases.
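Here is that sketch: a tiny self-contained example (not from our training code) of a pure loss function that JAX JIT-compiles and differentiates:

import jax
import jax.numpy as jnp

@jax.jit
def mse_loss(params, x, y):
    # Pure function: the output depends only on the inputs, with no hidden state.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))
print(mse_loss(params, x, y))            # compiled once by XLA, then cached
print(jax.grad(mse_loss)(params, x, y))  # automatic differentiation of the same function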

If you want to dive deeper into the benefits of JAX over PyTorch, I recommend reading the post "PyTorch is dead. Long live JAX".

JAX is especially great when working with non-NVIDIA hardware:

When working with AMD, JAX provides many benefits:

  1. Hardware-independent approach: JAX uses the XLA (Accelerated Linear Algebra) compiler, which compiles computations into a hardware-independent intermediate representation (the HLO graph). This allows the same JAX code to be optimized and executed efficiently, without modification, on different hardware backends, including AMD GPUs.

  2. Platform independent optimizations: The XLA compiler performs hardware-independent optimizations, benefiting all supported platforms.

  3. Simplified portability: When working with JAX, switching from NVIDIA to AMD (or other supported hardware) requires only minimal code changes. This is in contrast to PyTorch, which is more closely tied to NVIDIA's CUDA ecosystem.

    • PyTorch often relies on CUDA-specific implementations (e.g., torch.cuda calls, scaled_dot_product_attention).

    • While PyTorch supports other backends like ROCm for AMD GPUs, porting code can be challenging due to NVIDIA-specific code execution paths.

    • The process of “NVIDIA-freeing” PyTorch code can increase complexity and hinder portability.

Getting JAX ready for AMD is super easy!

Setting up JAX on AMD GPUs is a very simple process:

# Pull the Docker image:
docker pull rocm/jax:latest

# Run the Docker container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest

# Verify the installation:
python3 -c 'import jax; print(jax.devices())'

I was running an AMD node consisting of 8 AMD MI300x GPUs, each with 192 GB of HBM3 memory. They compare extremely well to the new NVIDIA H100 GPUs (comparison source: TensorWave).

LLaMA 405B Training: Performance and Scalability

Using JAX, I was able to train LLaMA 405B on AMD GPUs and achieve impressive results.

We performed LoRA fine-tuning with all model weights and LoRA parameters in bfloat16 precision, with LoRA rank = 8 and LoRA alpha = 16 (a configuration sketch follows the list below):

  • Model size: LLaMA model weights take up approximately 800 GB VRAM.

  • LoRA weights + optimizer state: approximately 400 GB VRAM.

  • Total VRAM Usage: 77% of total VRAM, approximately 1200 GB.

  • Constraints: Due to the sheer size of the 405B model, the room for batch sizes and sequence lengths was limited. I used a batch size of 16 and a sequence length of 64.

  • JIT compilation: Due to memory constraints, I could not run the JIT-compiled version; it apparently requires slightly more memory than the eager-mode graph.

  • Training speed: about 35 tokens per second in JAX eager mode (one training step took 30 s)

  • Memory efficiency: stable around 70%

  • Scaling: When running JAX, scaling was roughly linear across all 8 GPUs.
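For reference, the settings above correspond to a configuration roughly like the following sketch (the field names are illustrative, not the exact configuration objects from our repository):

# Illustrative fine-tuning configuration (field names are hypothetical)
train_config = dict(
    model="llama-3.1-405b",
    param_dtype="bfloat16",   # model weights and LoRA parameters
    lora_rank=8,
    lora_alpha=16,
    batch_size=16,
    seq_length=64,
    mesh_shape=(1, 8, 1),     # (dp, fsdp, mp) over the 8 MI300x GPUs
)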

Below are the GPU metrics, memory efficiency, and rocm-smi output for the 8 GPUs during one training step of the fine-tuning run:

rocm-smi output:

Device  Temperature  Power    Partitions  Fan  Perf  PwrCap   VRAM%  GPU%
0       58.0°C       232.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    27%
1       58.0°C       233.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    25%
2       56.0°C       236.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    24%
3       52.0°C       228.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    23%
4       59.0°C       232.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    22%
5       51.0°C       230.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    21%
6       61.0°C       235.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    18%
7       56.0°C       227.0 W  NPS1,SPX,0  0%   auto  750.0 W  77%    18%

Full details on GPU usage, VRAM, and rocm-smi data can be found in our GitHub repository.

Our training setup

We ported the LLaMA 3.1 architecture from PyTorch to JAX. Our implementation can be found in our GitHub repository.

This migration opened up new opportunities for us in terms of performance and scalability.

Loading the model and sharding parameters

Working with a model as huge as the LLaMA 405B requires efficient sharding of parameters across multiple devices. Below we will describe how we achieved this using JAX.

Sharding Options in JAX

To efficiently distribute the huge LLaMA 405B model across 8 AMD GPUs, we used JAX's device mesh feature (see the code pointer in our repository). A device mesh organizes the available devices into a multidimensional grid, allowing us to specify how computation and data are partitioned. In our system, we created a mesh of shape (1, 8, 1), with data parallelism (dp), fully sharded data parallelism (fsdp), and model parallelism (mp) axes. We then applied specific sharding rules to the model parameters, specifying for each model tensor how its dimensions are partitioned across the mesh axes.

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

DEVICES = jax.devices()
DEVICE_COUNT = len(DEVICES)
DEVICE_MESH = mesh_utils.create_device_mesh((1, 8, 1))
MESH = Mesh(devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))

Visualization of sharding

Array sharding can be visualized using jax.debug.visualize_array_sharding. This is incredibly useful for checking that sharding specifications are being applied correctly.
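A minimal sketch of such a check (assuming 8 visible devices, re-creating the mesh from above with a toy array):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS

# Re-create the (dp, fsdp, mp) mesh described above.
mesh = Mesh(mesh_utils.create_device_mesh((1, 8, 1)), axis_names=("dp", "fsdp", "mp"))

# Shard the first dimension of a toy weight matrix across the fsdp axis.
x = jax.device_put(jnp.ones((16, 16)), NamedSharding(mesh, PS("fsdp", "mp")))

# Prints an ASCII diagram showing which device holds which slice.
jax.debug.visualize_array_sharding(x)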

Partition rules

We determined partitioning rules for the different components of the model; a simplified sketch is shown below.

Method of sharding parameters
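The full set of rules lives in our repository; the sketch below maps illustrative parameter names to PartitionSpec entries over the (dp, fsdp, mp) mesh axes. The names and specs here are examples, not our exact rules; in the full implementation the rules are matched against the entire parameter tree.

from jax.sharding import PartitionSpec as PS

# Simplified, illustrative partitioning rules: parameter name -> PartitionSpec
partitioning_rules = {
    "embedding/kernel":        PS("mp", "fsdp"),   # token embeddings
    "attention/q_proj/kernel": PS("fsdp", "mp"),   # attention projections
    "attention/o_proj/kernel": PS("mp", "fsdp"),   # attention output projection
    "mlp/up_proj/kernel":      PS("fsdp", "mp"),   # MLP up-projection
    "mlp/down_proj/kernel":    PS("mp", "fsdp"),   # MLP down-projection
    "norm/scale":              PS(None),           # small tensors stay replicated
}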

Applying sharding restrictions

During the model loading process, we incrementally shard the model weights using special sharding functions:

from jax.sharding import NamedSharding

def make_shard_and_gather_fns(partition_specs):
    # partition_specs is a pytree of PartitionSpec objects matching the model
    # parameter tree (resolved from the partitioning rules above).
    def make_shard_fn(partition_spec):
        out_sharding = NamedSharding(MESH, partition_spec)
        def shard_fn(tensor):
            # Place the tensor on the devices prescribed by its sharding spec.
            return jax.device_put(tensor, out_sharding).block_until_ready()
        return shard_fn

    shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
    return shard_fns

# Create the sharding functions from the partitioning rules
shard_fns = make_shard_and_gather_fns(partitioning_rules)

This allows us to place each parameter on the appropriate devices with the specified sharding.

Sharding the training batch

Initially, the training batch is created in the usual way. Before passing it to the model, we shard it between GPUs according to the following code:

train_batch = jax.device_put(
    train_batch, NamedSharding(self.mesh, PS("dp", "fsdp"))
)

Here we specify that the training batch should be sharded across the data parallel ("dp") and fully sharded data parallel ("fsdp") axes, which in our mesh have sizes 1 and 8, so each batch is effectively split eight ways along the fsdp axis; the resulting layout can again be inspected with jax.debug.visualize_array_sharding.

Implementation of LoRA training

LoRA (Low-Rank Adaptation) reduces the number of trainable parameters by decomposing weight updates into low-rank matrices. This is especially useful for fine-tuning large models.

Key aspects of our LoRA implementation:

  1. Separate parameterization: We store LoRA parameters (lora_a and lora_b) separately from the main model parameters.

  2. Stop Gradient: We use jax.lax.stop_gradient(kernel) to prevent the main model weights from being updated.

  3. Efficient matrix multiplication: We use lax.dot_general for fast matrix operations with precision control.

  4. Scaling factor: Before adding to the main output, the LoRA output is scaled by (self.lora_alpha / self.lora_rank).

LoRADense layer

We implemented a custom LoRADense layer that includes the LoRA parameters:

from typing import Any

import flax.linen as nn
import jax
from jax import lax

class LoRADense(nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0

    @nn.compact
    def __call__(self, inputs: Any) -> Any:
        # Original kernel parameter (frozen)
        kernel = self.param('kernel', ...)
        y = lax.dot_general(inputs, jax.lax.stop_gradient(kernel), ...)

        # LoRA parameters (trainable)
        lora_a = self.variable('lora_params', 'lora_a', ..., ...)
        lora_b = self.variable('lora_params', 'lora_b', ..., ...)

        # Compute the LoRA output
        lora_output = lax.dot_general(inputs, lora_a.value, ...)
        lora_output = lax.dot_general(lora_output, lora_b.value, ...)

        # Combine the original output with the LoRA modification
        y += (self.lora_alpha / self.lora_rank) * lora_output

        return y.astype(self.dtype)
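Because the LoRA factors live in their own 'lora_params' collection, Flax returns them separately from the frozen kernel at initialization. A hypothetical usage sketch, assuming the elided pieces of LoRADense above are filled in:

import jax
import jax.numpy as jnp

layer = LoRADense(features=16)
variables = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 16)))

frozen = variables["params"]       # original kernel, kept frozen
lora = variables["lora_params"]    # lora_a / lora_b, the only trainable leaves

y = layer.apply({"params": frozen, "lora_params": lora}, jnp.ones((2, 16)))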

LoRA Parameter Sharding

To distribute the LoRA parameters efficiently across devices, we applied dedicated sharding rules in JAX. This ensures that the LoRA parameters are aligned with the sharding of the main model parameters, optimizing both memory usage and computational efficiency.

LoRA A matrices (lora_a)

  • Partition specification used: PS("fsdp", "mp").

  • Axis sharding: the lora_a parameters are sharded as (8, 1); the first axis is split across the 8 devices of the fsdp axis, and the second axis is left whole.

LoRA B matrices (lora_b)

  • Partition specification used: PS("mp", "fsdp").

  • Axis sharding: the lora_b parameters are sharded as (1, 8); the second axis (the matrix columns) is split across the 8 devices of the fsdp axis, and the first axis is left whole.

This sharding strategy optimizes parameter distribution, reduces communication overhead, and increases training parallelism. It ensures that each device only contains a fraction of the LoRA parameters, enabling efficient scaling for large models like the LLaMA 405B.
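In code, applying these specifications looks roughly like the following sketch (reusing the MESH defined earlier; the tensor shapes are illustrative, not the exact 405B dimensions):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as PS

# Illustrative LoRA factors for one layer: (hidden_dim, rank) and (rank, hidden_dim).
lora_a = jnp.zeros((16384, 8), dtype=jnp.bfloat16)
lora_b = jnp.zeros((8, 16384), dtype=jnp.bfloat16)

# PS("fsdp", "mp"): the rows of lora_a are split across the 8-way fsdp axis.
lora_a = jax.device_put(lora_a, NamedSharding(MESH, PS("fsdp", "mp")))
# PS("mp", "fsdp"): the columns of lora_b are split across the 8-way fsdp axis.
lora_b = jax.device_put(lora_b, NamedSharding(MESH, PS("mp", "fsdp")))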

Updating LoRA parameters only

To optimize training when fine-tuning the LLaMA 405B model, we compute gradients only for the LoRA parameters, keeping the main model parameters frozen. This approach reduces memory usage and speeds up training because we update far fewer parameters. Details of the implementation can be found in our GitHub repository.

Our training loop passes a batch of input data through the model at each step. Since only the LoRA parameters are trainable, gradients are computed with respect to them alone, and the frozen model weights act as constants during backpropagation. By focusing the updates on these parameters, we simplify the training process and make it possible to fine-tune extremely large models like LLaMA 405B efficiently across multiple GPUs.
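The sketch below shows the idea in a self-contained toy form; it is not our exact training loop. We assume optax is available for the optimizer, and the model is reduced to a single frozen dense layer with a LoRA correction:

import jax
import jax.numpy as jnp
import optax  # assumed optimizer library

def apply_model(frozen_params, lora_params, x):
    # Toy stand-in for the real model: a frozen dense layer plus a LoRA update.
    w = jax.lax.stop_gradient(frozen_params["w"])
    lora = lora_params["a"] @ lora_params["b"]        # low-rank correction
    return x @ (w + 2.0 * lora)                       # 2.0 ~ lora_alpha / lora_rank

def loss_fn(lora_params, frozen_params, x, y):
    pred = apply_model(frozen_params, lora_params, x)
    return jnp.mean((pred - y) ** 2)

optimizer = optax.adamw(1e-4)

@jax.jit
def train_step(lora_params, opt_state, frozen_params, x, y):
    # Gradients are taken only with respect to the first argument: the LoRA parameters.
    loss, grads = jax.value_and_grad(loss_fn)(lora_params, frozen_params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, lora_params)
    return optax.apply_updates(lora_params, updates), opt_state, loss

# Toy usage
key = jax.random.PRNGKey(0)
frozen_params = {"w": jax.random.normal(key, (16, 16))}
lora_params = {"a": 0.01 * jax.random.normal(key, (16, 8)), "b": jnp.zeros((8, 16))}
opt_state = optimizer.init(lora_params)
x, y = jnp.ones((4, 16)), jnp.zeros((4, 16))
lora_params, opt_state, loss = train_step(lora_params, opt_state, frozen_params, x, y)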

Conclusion

Fine-tuning the huge LLaMA 3.1 405B model on AMD GPUs using JAX left us with an extremely positive impression. By leveraging JAX's powerful parallelism capabilities and its hardware-agnostic approach, we were able to distribute the model efficiently across 8 AMD MI300x GPUs. Parameter sharding allowed us to manage the model's huge volume of parameters across devices, resulting in near-linear scalability and high memory efficiency.

This experience highlights the capabilities of AMD GPUs as a powerful alternative to NVIDIA hardware for large-scale AI training. The seamless integration of JAX with ROCm support simplifies the transition and opens up new possibilities for the AI research and development community. By sharing our experience and code, we hope to motivate others to explore and apply these tools in their own large-scale machine learning projects.
