# Using LoRA for efficient fine-tuning: Fundamental principles

Low-Rank Adaptation of Large Language Models (LoRA) addresses the challenges of fine-tuning large language models (LLMs). Models like GPT and Llama, which have billions of parameters, are typically cost-prohibitive to fine-tune for specific tasks or domains. LoRA keeps the pre-trained model weights frozen and injects small trainable low-rank matrices into each model block. The key benefit is that this substantially decreases the number of trainable parameters, sometimes by a factor of up to 10,000, which considerably reduces GPU memory requirements and overall compute cost.

## Why LoRA works

Pre-trained LLMs have a low “intrinsic dimension” when they are adapted to a new task, which means that the weight update needed for the adaptation can be effectively represented or approximated in a lower-dimensional space while retaining most of its essential information or structure. In other words, we can decompose the new weight matrix for the adapted task into lower-dimensional (smaller) matrices without losing much important information. We achieve this through low-rank approximation.

The rank of a matrix is a value that gives you an idea of the matrix’s complexity. A low-rank approximation of a matrix aims to approximate the original matrix as closely as possible, but with a lower rank. A lower-rank matrix reduces computational complexity, and thus increases the efficiency of matrix multiplications. Low-rank decomposition refers to the process of approximating a matrix `A` by factoring it into smaller matrices whose product has a lower rank than `A`. Singular value decomposition (SVD) is a common method for low-rank decomposition.
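
To make this concrete, the following is a minimal sketch of a rank-`r` approximation built from a truncated SVD. The matrix `A` and the choice `r = 10` are illustrative assumptions, not part of the original example:

```python
import torch

torch.manual_seed(0)

# An illustrative 1000 x 1000 matrix that is exactly rank 10 by construction
A = torch.randn(1000, 10) @ torch.randn(10, 1000)

# Singular value decomposition: A = U @ diag(S) @ Vh
U, S, Vh = torch.linalg.svd(A)

# Keep only the top-r singular values/vectors to form a rank-r approximation
r = 10
A_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]

# Reconstruction error is near zero, yet the factors store far fewer values:
# 1,000,000 entries in A vs. roughly 20,000 entries in the truncated factors
print(f'Reconstruction error: {torch.dist(A, A_r).item():.4f}')
print(f'Entries in A: {A.numel():,}')
print(f'Entries in factors: {U[:, :r].numel() + S[:r].numel() + Vh[:r, :].numel():,}')
```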

Suppose `W` represents the weight matrix in a given neural network layer, with dimensions `A × B`, and suppose `ΔW` is the weight update for `W` after a full fine-tuning. We can then decompose the weight update matrix `ΔW` into two smaller matrices, `ΔW = WA*WB`, where `WA` is an `A × r`-dimensional matrix and `WB` is an `r × B`-dimensional matrix, with the rank `r` much smaller than `A` and `B`. Here, we keep the original weight `W` frozen and only train the new matrices `WA` and `WB`. This summarizes the LoRA method.
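
As a hypothetical illustration of the savings, suppose `W` (and therefore `ΔW`) is a `1000 × 1000` matrix: a full update would train 1,000,000 values, whereas with rank `r = 8` the factor `WA` has `1000 × 8 = 8,000` parameters and `WB` has `8 × 1000 = 8,000`, so only 16,000 parameters (about 1.6% of the original) need to be trained.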

## The benefits of LoRA

- **Reduced resource consumption.** Fine-tuning deep learning models typically requires substantial computational resources, which can be expensive and time-consuming. LoRA reduces the demand for resources while maintaining high performance.
- **Faster iterations.** LoRA enables rapid iterations, making it easier to experiment with different fine-tuning tasks and adapt models quickly.
- **Improved transfer learning.** LoRA enhances the effectiveness of transfer learning, as models with LoRA adapters can be fine-tuned with fewer data. This is particularly valuable in situations where labeled data are scarce.
- **Broad applicability.** LoRA is versatile and can be applied across diverse domains, including natural language processing, computer vision, and speech recognition.
- **Lower carbon footprint.** By reducing computational requirements, LoRA contributes to a greener and more sustainable approach to deep learning.

## Train a neural network using the LoRA technique

Our goal is to train a neural network to classify the MNIST database of handwritten digits. We then fine-tune this network to improve its performance for a digit on which it doesn’t initially perform well.

Hardware: AMD Instinct GPU

Software: PyTorch, torchvision, tqdm

The code used in this blog post is adapted from Umar Jamil's LoRA implementation, with due credit to the author.

### Getting started

Import the packages.

```python
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm
```

Set the seed for generating random numbers so that the results are deterministic.

```python
# Make torch deterministic
_ = torch.manual_seed(0)
```

Load the data set.

```python
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST data set
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```

Create the neural network to classify the digits. We deliberately make it larger than necessary in order to better showcase LoRA.

```python
# Create an overly expensive neural network to classify MNIST digits
# Daddy got money, so I don't care about efficiency
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)
```

Train the network for one epoch to simulate a complete general pre-training on the data. This process takes seconds on an AMD Instinct GPU.

```python
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, net, epochs=1)
```

```
Epoch 1: 100%|██████████| 6000/6000 [00:20<00:00, 299.74it/s, loss=0.237]
```

> [!TIP]
> Keep a copy of the original weights (clone them) so you can later verify that fine-tuning doesn’t alter them.

```python
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()
```

### Fine-tuning

Choose a digit to fine-tune on. To do so, first define a test function that reports the overall accuracy and the per-digit error counts. As the output below shows, the pre-trained network performs poorly on digit 9, so that’s the digit we’ll fine-tune on.

```python
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()
```

```
Testing: 100%|██████████| 1000/1000 [00:02<00:00, 497.86it/s]
Accuracy: 0.951
wrong counts for the digit 0: 35
wrong counts for the digit 1: 31
wrong counts for the digit 2: 26
wrong counts for the digit 3: 81
wrong counts for the digit 4: 34
wrong counts for the digit 5: 15
wrong counts for the digit 6: 74
wrong counts for the digit 7: 67
wrong counts for the digit 8: 11
wrong counts for the digit 9: 116
```

Visualize how many parameters are in the original network before introducing the LoRA matrices.

```python
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')
```

```
Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010
```

Define the LoRA parameterization.

```python
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper:
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the
        #   beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        #   We then scale ∆Wx by α/r, where α is a constant in r.
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we
        #   scale the initialization appropriately.
        #   As a result, we simply set α to the first r we try and do not tune it.
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
```
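
Note that `lora_B` is initialized to zeros while `lora_A` receives a random Gaussian initialization, so the product `B*A` is zero when training begins. Enabling the parametrization therefore leaves the network’s behavior identical to the pre-trained model at first; the adapters only start to contribute as they are trained.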

Add the parameterization to our network. You can learn more about PyTorch parametrizations on PyTorch.org.

```python
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias
    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the
    #   MLP modules (so they are not trained in downstream tasks) both for simplicity and
    #   parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
```

Display the number of parameters added by LoRA.

```python
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')
```

```
Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters increment: 0.242%
```

Freeze all the parameters of the original network and only fine-tune the ones introduced by LoRA. Then, fine-tune the model on digit 9 for 100 batches.

```python
# Freeze the non-LoRA parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST data set again, keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would
# improve the performance on the digit 9)
train(train_loader, net, epochs=1, total_iterations_limit=100)
```

```
Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original
Epoch 1: 99%|█████████▉| 99/100 [00:00<00:00, 200.52it/s, loss=0.132]
```

Verify that the fine-tuning didn’t alter the original weights and that only the weights introduced by LoRA were updated.

```python
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# Now let's use net.linear1 as an example to check that LoRA is applied to the model
# correctly, as defined in LoRAParametrization.forward()
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
```

Test the network with LoRA enabled (the digit 9 should be classified better).

```python
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()
```

```
Testing: 100%|██████████| 1000/1000 [00:02<00:00, 471.08it/s]
Accuracy: 0.905
wrong counts for the digit 0: 144
wrong counts for the digit 1: 34
wrong counts for the digit 2: 30
wrong counts for the digit 3: 216
wrong counts for the digit 4: 161
wrong counts for the digit 5: 73
wrong counts for the digit 6: 93
wrong counts for the digit 7: 100
wrong counts for the digit 8: 95
wrong counts for the digit 9: 6
```

Test the network with LoRA disabled (the accuracy and error counts must be the same as for the original network).

```python
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()
```

```
Testing: 100%|██████████| 1000/1000 [00:01<00:00, 517.04it/s]
Accuracy: 0.951
wrong counts for the digit 0: 35
wrong counts for the digit 1: 31
wrong counts for the digit 2: 26
wrong counts for the digit 3: 81
wrong counts for the digit 4: 34
wrong counts for the digit 5: 15
wrong counts for the digit 6: 74
wrong counts for the digit 7: 67
wrong counts for the digit 8: 11
wrong counts for the digit 9: 116
```

> [!NOTE]
> You might observe that fine-tuning has impacted the accuracy on the other digits. This is expected, as our fine-tuning was exclusively focused on digit 9.