Fine-tuning is the process of adapting a pre-trained model to perform better on specific tasks by training it on domain-specific data. Learn how to fine-tune Jamba models using different approaches, including full fine-tuning, LoRA, and QLoRA.
The Jamba models can be fine-tuned using several approaches:
Full fine-tuning updates all model parameters and provides the most comprehensive training results.
For a comprehensive implementation guide using AWS SageMaker with multi-node and FSDP configuration, see the AI21 SageMaker Fine-tuning Repository.
Full fine-tuning requires multiple high-memory GPUs.
LoRA (Low-Rank Adaptation) fine-tuning injects compact, low-rank adapter layers into a frozen pretrained model, letting you specialize it for your task by training only a few percent of the parameters, with minimal extra compute and storage and only a small cost in accuracy or inference speed.
Before starting LoRA fine-tuning, install the required dependencies:
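The exact requirements depend on your environment, but a typical setup for LoRA fine-tuning with TRL and PEFT looks roughly like the following; the package list is an assumption rather than an official requirements file:

```bash
pip install torch transformers datasets accelerate peft trl

# Optional kernels that Jamba's Mamba layers and attention blocks can use;
# skip these if you cannot build CUDA extensions in your environment.
pip install mamba-ssm causal-conv1d flash-attn
```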
This LoRA fine-tuning example uses bfloat16 precision and requires ~130GB GPU RAM (e.g., 2x A100 80GB GPUs).
Load Model and Tokenizer
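A minimal sketch of this step, assuming the Hugging Face checkpoint ID ai21labs/AI21-Jamba-Mini-1.7 and optional FlashAttention support; adjust the model ID and attention implementation to your setup:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ai21labs/AI21-Jamba-Mini-1.7"  # assumed checkpoint ID; swap in the Jamba model you are tuning

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,               # bfloat16 precision, as noted above
    attn_implementation="flash_attention_2",  # optional; drop if flash-attn is not installed
    device_map="auto",                        # spreads the model across the available GPUs
)
```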
Configure LoRA Parameters
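A possible LoraConfig for Jamba. The rank, alpha, and target module names below are illustrative assumptions (the module names follow the attention, MLP, and Mamba projections in the transformers Jamba implementation) and should be tuned for your task:

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                      # adapter rank
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",      # attention projections
        "gate_proj", "up_proj", "down_proj",         # MLP / MoE expert projections
        "in_proj", "x_proj", "dt_proj", "out_proj",  # Mamba mixer projections
    ],
)
```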
Prepare Your Dataset
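For illustration only, here is a tiny in-memory dataset in the conversational messages format described in the note below; in practice you would load your own data, for example with datasets.load_dataset:

```python
from datasets import Dataset

# Each row has a "messages" column in chat format, which SFTTrainer can
# render with Jamba's chat template.
train_dataset = Dataset.from_list([
    {
        "messages": [
            {"role": "user", "content": "What is LoRA fine-tuning?"},
            {"role": "assistant", "content": "LoRA trains small low-rank adapter matrices on top of a frozen base model."},
        ]
    },
    {
        "messages": [
            {"role": "user", "content": "Name one benefit of adapters."},
            {"role": "assistant", "content": "They add only a few percent of trainable parameters."},
        ]
    },
])
```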
Configure Training Settings
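A sketch of the training configuration using TRL's SFTConfig (which extends transformers.TrainingArguments); the hyperparameter values and output path are placeholders, not recommended settings:

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="./jamba-mini-lora",      # hypothetical output path
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,                           # matches the bfloat16 setup above
    gradient_checkpointing=True,
)
```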
Initialize and Start Training
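Putting the pieces from the previous steps together; note that the keyword for passing the tokenizer differs across TRL releases:

```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=lora_config,        # SFTTrainer wraps the model with the LoRA adapters
    processing_class=tokenizer,     # older TRL versions take tokenizer=tokenizer instead
)

trainer.train()
trainer.save_model("./jamba-mini-lora")  # saves the trained adapter weights
```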
The dataset in this example uses the conversational format (with a messages column), so SFTTrainer automatically applies Jamba's chat template. For more information about supported dataset formats and advanced SFTTrainer features, see the TRL documentation.
For Jamba Large LoRA fine-tuning, we recommend using the QLoRA+FSDP approach detailed in the QLoRA section below, as it provides better memory efficiency for the larger model.
QLoRA combines LoRA with 4-bit quantization, making it possible to fine-tune on a single 80GB GPU while maintaining good performance.
Before starting QLoRA fine-tuning, install the required dependencies:
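As with LoRA, the concrete requirements depend on your environment; a typical QLoRA setup additionally needs bitsandbytes for 4-bit quantization (the package list is an assumption):

```bash
pip install torch transformers datasets accelerate peft trl bitsandbytes

# Optional Jamba-specific kernels, as in the LoRA setup above.
pip install mamba-ssm causal-conv1d flash-attn
```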
Initialize Tokenizer and Configure Quantization
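A minimal sketch, again assuming the ai21labs/AI21-Jamba-Mini-1.7 checkpoint ID; the NF4 settings below are the commonly used QLoRA defaults rather than official recommendations:

```python
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig

model_id = "ai21labs/AI21-Jamba-Mini-1.7"  # assumed checkpoint ID

tokenizer = AutoTokenizer.from_pretrained(model_id)

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_quant_type="nf4",              # NF4 quantization, standard for QLoRA
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bfloat16
)
```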
Load Model with Quantization
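Loading the model with the quantization config from the previous step; the FlashAttention line is optional and assumes flash-attn is installed:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # optional; remove if flash-attn is unavailable
    device_map="auto",                        # fits on a single 80GB GPU in 4-bit
)
```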
Configure LoRA Parameters
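The adapter configuration has the same shape as in the LoRA section; the values and target module names are again illustrative assumptions:

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",      # attention
        "gate_proj", "up_proj", "down_proj",         # MLP / MoE experts
        "in_proj", "x_proj", "dt_proj", "out_proj",  # Mamba mixer
    ],
)
```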
Prepare Your Dataset
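Any dataset with a messages column works here. As an assumed example, this loads a local JSONL file where each line contains a messages list; the file path is a placeholder:

```python
from datasets import load_dataset

# "train.jsonl" is a placeholder path; each JSON line should look like
# {"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}
train_dataset = load_dataset("json", data_files="train.jsonl", split="train")
```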
Configure Training Settings
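A sketch of the QLoRA training settings; the paged 8-bit optimizer is a common bitsandbytes choice for keeping optimizer state small, and all values here are placeholders:

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="./jamba-mini-qlora",     # hypothetical output path
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",            # bitsandbytes paged optimizer, common for QLoRA
)
```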
Initialize and Start Training
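Finally, wiring everything together, mirroring the LoRA example above:

```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=lora_config,
    processing_class=tokenizer,   # older TRL versions take tokenizer=tokenizer instead
)

trainer.train()
trainer.save_model("./jamba-mini-qlora")  # saves the adapter weights
```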
Jamba Large fine-tuning requires 8x A100/H100 80GB GPUs and uses QLoRA+FSDP. This approach uses the axolotl framework with a modified transformers version to optimize memory usage.
Due to its size, Jamba Large 1.7 has to be quantized in order to run training on a single 8-GPU node. This can happen either at the start of the training job or in a preprocessing step. If you want to pre-quantize the model, you can do so easily with bitsandbytes (make sure to set bnb_4bit_quant_storage=torch.bfloat16 so you can use FSDP).
Install Dependencies
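The axolotl Jamba examples pin the exact versions, including the modified transformers build mentioned above, so treat the following as a rough, assumed starting point and defer to that repository:

```bash
# Assumed baseline install; follow the axolotl Jamba examples for the
# pinned versions and the modified transformers build.
pip install axolotl
pip install bitsandbytes accelerate
```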
Pre-quantize Model (Optional)
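A sketch of pre-quantizing the model with bitsandbytes and saving the 4-bit checkpoint to disk; the checkpoint ID and output path are assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "ai21labs/AI21-Jamba-Large-1.7"  # assumed checkpoint ID
output_dir = "jamba-large-1.7-4bit"         # hypothetical local path

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # required so FSDP can shard the 4-bit weights
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Save the quantized checkpoint so the training job can load it directly.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
```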
Run Training
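The training run itself is driven by an axolotl YAML config from the axolotl Jamba examples; the config filename below is a placeholder:

```bash
# Launch QLoRA+FSDP training across the 8 GPUs on the node.
# "jamba-large-qlora-fsdp.yaml" stands in for the config file provided
# in the axolotl Jamba examples.
accelerate launch -m axolotl.cli.train jamba-large-qlora-fsdp.yaml

# Newer axolotl releases also expose a CLI entry point:
# axolotl train jamba-large-qlora-fsdp.yaml
```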
For detailed configuration files and examples, visit the axolotl Jamba examples. The modified transformers version prevents excessive CPU RAM usage during loading, reducing the requirement from over 1.6TB to roughly 200GB.