Overview

The Jamba models can be fine-tuned using several approaches:

  • Full Fine-tuning: Complete model parameter updates (requires significant GPU resources)
  • LoRA (Low-Rank Adaptation): Parameter-efficient fine-tuning approach
  • QLoRA: Combines LoRA with 4-bit quantization for single GPU training

Full Fine-tuning

Full fine-tuning updates all model parameters, giving training the most capacity to adapt the model at the cost of the highest GPU requirements.

For a complete implementation guide using AWS SageMaker with a multi-node FSDP configuration, see the AI21 SageMaker Fine-tuning Repository.

Full fine-tuning requires multiple high-memory GPUs.
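
The linked repository covers the full SageMaker and FSDP setup. As a minimal local sketch only (assuming sufficient GPU memory and the same TRL tooling used in the LoRA example below), full fine-tuning is simply the SFTTrainer flow without a peft_config, so every parameter receives gradients:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

training_args = SFTConfig(
    output_dir="./results-full",  # placeholder output path
    num_train_epochs=2,
    per_device_train_batch_size=1,  # full fine-tuning leaves little memory headroom
    learning_rate=1e-5,
    gradient_checkpointing=True,
    max_seq_length=4096,
)

# No peft_config: all model parameters are trained.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()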

LoRA Fine-tuning

LoRA (Low-Rank Adaptation) fine-tuning injects compact, low-rank adapter layers into a frozen pretrained model, letting you specialize it for your task while training only a few percent of the parameters, with minimal extra compute and storage and only a small cost in accuracy or inference speed.
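
Conceptually, each adapted weight matrix W stays frozen and the learned update is factored into two small matrices, so the effective weight becomes W + B @ A with rank r much smaller than the layer dimensions. A toy PyTorch illustration of the idea (simplified; not the actual PEFT implementation, and omitting the usual alpha/r scaling):

import torch

d_out, d_in, r = 4096, 4096, 8                  # r matches the LoRA rank used below
W = torch.randn(d_out, d_in)                    # frozen pretrained weight
A = torch.randn(r, d_in, requires_grad=True)    # trainable (r x d_in)
B = torch.zeros(d_out, r, requires_grad=True)   # trainable (d_out x r), zero-initialized

x = torch.randn(d_in)
y = W @ x + B @ (A @ x)                         # frozen path plus low-rank update

print(f"trainable fraction: {(A.numel() + B.numel()) / W.numel():.2%}")  # ~0.39% at r=8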

Prerequisites

Before starting LoRA fine-tuning, install the required dependencies:

pip install trl transformers torch datasets peft

This LoRA fine-tuning example uses bfloat16 precision and requires roughly 130 GB of GPU memory (e.g., 2x A100 80GB GPUs).
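
Before launching, a quick sanity check of the visible GPU memory can save a failed run (a small optional snippet, not part of the training script):

import torch

total_gb = sum(
    torch.cuda.get_device_properties(i).total_memory
    for i in range(torch.cuda.device_count())
) / 1024**3
print(f"Visible GPU memory: {total_gb:.0f} GB")  # aim for roughly 130 GB or more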

Implementation

1. Load Model and Tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    device_map="auto",  # Automatically distribute across available GPUs
    torch_dtype=torch.bfloat16,  # Use mixed precision for memory efficiency
    attn_implementation="flash_attention_2",  # Optimized attention implementation
)

2. Configure LoRA Parameters

lora_config = LoraConfig(
    r=8,  # Rank of adaptation - controls the number of trainable parameters
    target_modules=[
        "embed_tokens",
        "x_proj", "in_proj", "out_proj",  # mamba layers
        "gate_proj", "up_proj", "down_proj",  # mlp layers
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention layers
    ],
    task_type="CAUSAL_LM",
    bias="none",
)
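
SFTTrainer applies the LoRA config for you, so no manual wrapping is needed. If you want to see how many parameters r=8 actually makes trainable, an optional check with PEFT looks like this (note that get_peft_model injects adapters into the model in place, so either run it in a throwaway session or pass the wrapped model to SFTTrainer and drop peft_config):

from peft import get_peft_model

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # prints trainable vs. total parameter counts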

3. Prepare Your Dataset

# Load dataset (replace with your own dataset)
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
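
The dolly-15k-oai-style dataset already stores each example as an OpenAI-style messages list. If you are bringing your own data, a minimal sketch of building a compatible dataset in memory looks like this (example content is illustrative):

from datasets import Dataset

my_examples = [
    {
        "messages": [
            {"role": "user", "content": "What is Jamba?"},
            {"role": "assistant", "content": "Jamba is AI21's hybrid SSM-Transformer model family."},
        ]
    },
]
dataset = Dataset.from_list(my_examples)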

4. Configure Training Settings

training_args = SFTConfig(
    output_dir="/dev/shm/results",  # Where to save the model
    logging_dir="./logs",  # Where to save training logs
    num_train_epochs=2,  # Number of training epochs
    per_device_train_batch_size=4,  # Batch size per GPU
    learning_rate=1e-5,  # Learning rate for fine-tuning
    logging_steps=10,  # Log training metrics every 10 steps
    gradient_checkpointing=True,  # Save memory at cost of compute
    max_seq_length=4096,  # Maximum sequence length
    save_steps=100,  # Save model checkpoint every 100 steps
)

5. Initialize and Start Training

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
)

trainer.train()

The dataset in this example uses the conversational format (with a messages column), so SFTTrainer automatically applies Jamba's chat template. For more information about supported dataset formats and advanced SFTTrainer features, see the TRL documentation.
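
After training, the LoRA adapter can be saved on its own and later attached to (or merged into) the base model for inference. A minimal sketch using PEFT (paths are placeholders):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Save just the adapter weights from the trained model.
trainer.save_model("./jamba-lora-adapter")

# Later: reload the base model and attach the adapter for inference.
base = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tuned = PeftModel.from_pretrained(base, "./jamba-lora-adapter")
tuned = tuned.merge_and_unload()  # optional: fold the adapter into the base weights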

QLoRA Fine-tuning

QLoRA combines LoRA with 4-bit quantization, making it possible to fine-tune on a single 80GB GPU while maintaining good performance.

Prerequisites

Before starting QLoRA fine-tuning, install the required dependencies:

pip install trl transformers torch datasets peft bitsandbytes

Implementation

1. Initialize Tokenizer and Configure Quantization

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")

# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4",  # Use NormalFloat 4-bit quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in bfloat16 for better stability
)

2. Load Model with Quantization

model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    device_map="auto",  # Automatically distribute across available GPUs
    quantization_config=quantization_config,  # Apply 4-bit quantization
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
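
To confirm the 4-bit load actually fits on a single 80GB GPU, you can print the model's reported weight footprint (an optional check; exact numbers will vary):

print(f"{model.get_memory_footprint() / 1024**3:.1f} GB")  # rough size of the quantized weights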

3. Configure LoRA Parameters

lora_config = LoraConfig(
    r=8,  # Rank of adaptation - controls trainable parameters
    target_modules=[
        "embed_tokens", 
        "x_proj", "in_proj", "out_proj",  # mamba layers
        "gate_proj", "up_proj", "down_proj",  # mlp layers
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention layers
    ],
    task_type="CAUSAL_LM",
    bias="none",
)

4. Prepare Your Dataset

# Load dataset (replace with your own dataset)
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

5. Configure Training Settings

training_args = SFTConfig(
    output_dir="./results",  # Where to save the model
    logging_dir="./logs",  # Where to save training logs
    num_train_epochs=2,  # Number of training epochs
    per_device_train_batch_size=8,  # Higher batch size possible with quantization
    learning_rate=1e-5,  # Learning rate for fine-tuning
    logging_steps=1,  # Log training metrics every step
    gradient_checkpointing=True,  # Save memory at cost of compute
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Use the non-reentrant checkpointing implementation (required for some models)
    save_steps=100,  # Save model checkpoint every 100 steps
    max_seq_length=4096,  # Maximum sequence length
)

6. Initialize and Start Training

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
)

trainer.train()
