You have successfully downloaded the Llama 3 8B weights. You set up your PyTorch environment, configured a basic LoRA adapter, and launched the training script on your RTX 3090 or Google Colab T4.
Then, before the first epoch even starts, you hit the wall: RuntimeError: CUDA out of memory. Tried to allocate...
This is the single most common barrier for engineers moving from using LLMs to fine-tuning them. It is frustrating because the math seems like it should work: if Llama 3 8B is roughly 16GB in half-precision (FP16), why does it crash a 24GB or even a 40GB card?
The answer lies in the hidden memory overhead of the training process itself. This guide provides a root cause analysis of VRAM consumption and a production-grade code solution using QLoRA (Quantized LoRA) to fit Llama 3 training pipelines onto consumer hardware.
The Root Cause: Where Did the VRAM Go?
To fix the OOM (Out Of Memory) error, you must understand what consumes GPU memory. It is not just the model weights.
When you load Llama 3 8B in standard FP16 (16-bit floating point), the Model Weights occupy approximately 16GB of VRAM ($8 \times 10^9 \text{ parameters} \times 2 \text{ bytes}$).
On a 16GB T4 GPU, you are effectively at capacity just by loading the model. However, training requires three additional memory buckets:
- Optimizer States: Standard AdamW maintains two states (momentum and variance) per trainable parameter, stored in FP32, for 8 bytes per parameter. For full fine-tuning of an 8B model, that is an additional 64GB of VRAM. Even 8-bit optimizers consume substantial space.
- Gradients: You need to store the gradients for every trainable parameter.
- Activations: This is the silent killer. During the forward pass, PyTorch stores the intermediate outputs of every layer so that it can compute gradients during backpropagation. Most of this memory scales linearly with sequence length, but the attention matrices of a standard (non-flash) implementation scale quadratically with it.
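The bullets above are easy to check with back-of-envelope arithmetic. Here is a minimal sketch of the budget for full fine-tuning of an 8B model (FP16 weights and gradients, FP32 AdamW states, activations excluded since they depend on batch and sequence length):

```python
# Rough VRAM budget for FULL fine-tuning of an 8B-parameter model.
# Activations are excluded; they add more on top of this total.
PARAMS = 8e9
GB = 1e9  # decimal gigabytes, matching the figures above

weights = PARAMS * 2 / GB    # FP16 weights: 2 bytes/param
gradients = PARAMS * 2 / GB  # FP16 gradients: 2 bytes/param
optimizer = PARAMS * 8 / GB  # FP32 AdamW momentum + variance: 4 + 4 bytes/param

total = weights + gradients + optimizer
print(f"weights:   {weights:.0f} GB")    # 16 GB
print(f"gradients: {gradients:.0f} GB")  # 16 GB
print(f"optimizer: {optimizer:.0f} GB")  # 64 GB
print(f"total (before activations): {total:.0f} GB")  # 96 GB
```

96GB before a single activation is stored: no consumer card survives this, which is why PEFT alone is not enough and the base model itself must shrink.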
Even with PEFT (Parameter-Efficient Fine-Tuning) techniques like LoRA, loading the base model in 16-bit precision creates a baseline memory footprint that is too large for consumer cards.
The Solution: 4-bit Quantization and Gradient Checkpointing
To fine-tune Llama 3 on a T4 (16GB) or RTX 3090 (24GB), we must implement a specific stack of optimizations:
- 4-bit Quantization (QLoRA): We load the base model in 4-bit precision (NF4 format). This compresses the 16GB base model down to roughly 5.5GB.
- Paged Optimizers: We use paged_adamw_8bit. If the GPU runs out of memory, this optimizer automatically offloads optimizer states to CPU RAM, preventing the crash.
- Gradient Checkpointing: Instead of storing all activations, we throw some away and recompute them during the backward pass. This trades a small amount of compute speed for massive memory savings.
The Implementation
Below is a complete, drop-in Python script using transformers, peft, and bitsandbytes.
Prerequisites
Ensure you have the latest versions of the required libraries. Older versions of transformers or accelerate may not support Llama 3 correctly.
pip install -U torch torchvision torchaudio
pip install -U transformers peft bitsandbytes trl accelerate
The Training Script
This code configures the specific quantization and LoRA parameters required to stabilize memory usage.
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from peft import (
LoraConfig,
prepare_model_for_kbit_training,
get_peft_model,
)
from trl import SFTTrainer
from datasets import load_dataset
# 1. Configuration
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
NEW_MODEL_NAME = "Llama-3-8B-FineTuned"
# 2. 4-Bit Quantization Configuration (The Memory Saver)
# This reduces model weight size by ~75%
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # Normalized Float 4 (optimized for LLMs)
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
# 3. Load Base Model
# device_map="auto" allows Accelerate to handle device placement
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
# 4. Load Tokenizer & Fix Padding
# Llama 3 does not have a default pad token, causing training errors
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issues with fp16
# 5. Prepare Model for QLoRA
# This enables gradient checkpointing and prepares layers for 4-bit training
model = prepare_model_for_kbit_training(model)
# 6. LoRA Configuration
# rank (r): Lower = less VRAM, higher = more capacity (usually).
# 16 is a memory-friendly choice for a 16GB card.
peft_config = LoraConfig(
r=16,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, peft_config)
# 7. Training Arguments
# 'paged_adamw_8bit' is critical for avoiding OOM spikes
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=1,
per_device_train_batch_size=4, # Keep low for T4 (4 is safe, 2 is safer)
gradient_accumulation_steps=1,
optim="paged_adamw_8bit",
save_steps=25,
logging_steps=25,
learning_rate=2e-4,
weight_decay=0.001,
fp16=True, # Use mixed precision (set to False if you switch to bf16)
bf16=False, # Set to True (and fp16=False) on Ampere or newer (RTX 3090/4090, A100)
max_grad_norm=0.3,
warmup_ratio=0.03,
group_by_length=True,
lr_scheduler_type="constant",
report_to="none" # Disable wandb for this snippet
)
# 8. Dummy Dataset for Demonstration
# Replace this with your specific dataset loading logic
dataset = load_dataset("mlabonne/guanaco-llama2-1k", split="train")
# 9. Initialize Trainer
# Note: recent trl releases (>= 0.12) move dataset_text_field, max_seq_length,
# and packing into SFTConfig; the kwargs below match the older SFTTrainer API.
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=512, # Lower this if you still OOM (e.g., to 256)
tokenizer=tokenizer,
args=training_args,
packing=False,
)
# 10. Train
print("Starting training...")
trainer.train()
# 11. Save Adapter
trainer.model.save_pretrained(NEW_MODEL_NAME)
print(f"Model saved to {NEW_MODEL_NAME}")
Deep Dive: Why This Fix Works
NF4 Quantization vs. FP4
In the BitsAndBytesConfig, we selected nf4. Standard 4-bit float types divide the numeric range evenly. However, neural network weights usually follow a normal distribution (bell curve). NF4 (Normal Float 4) is a data type specifically designed for this distribution. It offers higher fidelity representation of the weights than standard floats, meaning we lose less accuracy despite the massive compression.
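A toy experiment makes this concrete. The sketch below uses empirical quantiles of the data rather than the actual NF4 codebook, and numpy instead of bitsandbytes, but it shows why quantile-placed levels beat evenly spaced ones on bell-curve weights:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(0.0, 1.0, size=100_000)  # LLM weights roughly follow a bell curve

def quantize(x, levels):
    # Snap each value to the nearest representable level
    idx = np.abs(x[:, None] - levels[None, :]).argmin(axis=1)
    return levels[idx]

# 16 levels (4 bits) spread evenly across the observed range
uniform_levels = np.linspace(weights.min(), weights.max(), 16)

# 16 levels placed at equally spaced quantiles of the data (NF4's core idea)
quantile_levels = np.quantile(weights, np.linspace(0.005, 0.995, 16))

mse_uniform = np.mean((weights - quantize(weights, uniform_levels)) ** 2)
mse_quantile = np.mean((weights - quantize(weights, quantile_levels)) ** 2)
print(f"uniform levels MSE:  {mse_uniform:.5f}")
print(f"quantile levels MSE: {mse_quantile:.5f}")
```

The quantile-based levels produce a noticeably lower reconstruction error because they concentrate precision where the weights actually live, near zero, instead of wasting codes on the nearly empty tails.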
Paged AdamW
The optimizer paged_adamw_8bit leverages a feature in NVIDIA CUDA called Unified Memory. When the GPU VRAM hits 100% capacity due to a spike in optimizer states, this optimizer automatically pages memory blocks to the system RAM (CPU memory). It acts as a safety net, preventing the hard crash that usually terminates training.
Target Modules
In the LoraConfig, we targeted all linear layers (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj). While targeting only q_proj and v_proj saves more memory, research indicates that targeting all linear layers yields significantly better model performance. With 4-bit quantization, we have enough VRAM headroom to afford targeting all modules.
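As a sanity check on the cost of targeting every linear layer, you can count the added parameters by hand. This sketch assumes the published Llama 3 8B dimensions (32 layers, hidden size 4096, intermediate size 14336, grouped-query attention with a 1024-wide KV projection) and the r=16 from the script above:

```python
# Llama 3 8B dimensions (per the published model config)
HIDDEN, INTERMEDIATE, KV_DIM, LAYERS = 4096, 14336, 1024, 32
R = 16  # LoRA rank used in the training script

# (d_in, d_out) for each targeted linear layer
modules = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# Each LoRA adapter adds r * (d_in + d_out) parameters per matrix
per_layer = sum(R * (din + dout) for din, dout in modules.values())
total = per_layer * LAYERS
print(f"Trainable LoRA params: {total/1e6:.1f}M (~{total/8e9:.2%} of the 8B base)")
```

Roughly 42M trainable parameters, about half a percent of the base model, so even "target everything" keeps the gradient and optimizer-state footprint tiny.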
Common Pitfalls and Edge Cases
1. The "Pad Token" Error
Llama 3's tokenizer is different from Llama 2's. It does not have a designated padding token by default. If you attempt to train without tokenizer.pad_token = tokenizer.eos_token, the loss calculation may fail or produce NaN because the model cannot properly mask the padding tokens in the batch.
2. Flash Attention Compatibility
If you are using an RTX 3090, 4090, or A100, you should enable Flash Attention 2 for faster training.
model = AutoModelForCausalLM.from_pretrained(
...,
attn_implementation="flash_attention_2"
)
Warning: Do not enable this on a T4 or V100; these older cards do not support Flash Attention 2, and it will throw an error.
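A defensive pattern is to pick the attention backend at runtime from the GPU's compute capability (Ampere starts at 8.0; the T4 is 7.5 and the V100 is 7.0), falling back to PyTorch's built-in SDPA. This is a sketch; `pick_attn_implementation` is a hypothetical helper name:

```python
import torch

def pick_attn_implementation() -> str:
    # Flash Attention 2 requires an Ampere-or-newer GPU (compute capability >= 8.0).
    if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
        return "flash_attention_2"
    return "sdpa"  # PyTorch's built-in scaled-dot-product attention

attn_impl = pick_attn_implementation()
print(f"Using attention implementation: {attn_impl}")
```

The returned string can be passed directly as the attn_implementation argument to from_pretrained, so the same script runs unchanged on a T4 and a 4090.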
3. Artifacting During Inference
After training with QLoRA, you have an adapter (a small file) and a 4-bit quantized base model. If you merge them for inference, reload the base model in FP16 (not 4-bit) first, merge the LoRA weights into it, and then save. Merging directly into 4-bit weights usually degrades quality significantly.
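A sketch of that merge flow using peft's standard merge_and_unload pattern (the adapter path matches NEW_MODEL_NAME from the script above; the merged output directory name is a placeholder). This requires downloading the full FP16 weights, so run it on a machine with enough disk and RAM:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Reload the base model in FP16 -- NOT with the 4-bit quantization config
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the trained adapter, then fold its weights into the base model
model = PeftModel.from_pretrained(base_model, "Llama-3-8B-FineTuned")
model = model.merge_and_unload()

# Save the merged FP16 model (placeholder output path) alongside its tokenizer
model.save_pretrained("Llama-3-8B-Merged")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.save_pretrained("Llama-3-8B-Merged")
```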
Conclusion
Fine-tuning Llama 3 on consumer hardware is a game of memory management. By aggressively quantizing the base model to 4-bit and utilizing paged optimizers, you can fit an 8B parameter model into the 16GB VRAM constraint of a standard Google Colab T4 or a local mid-range GPU.
This approach allows you to move past infrastructure debugging and focus on what actually matters: dataset quality and model evaluation.